[ALL] Merge master (#447)
This commit is contained in:
		
							
								
								
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@@ -1,6 +1,9 @@
 | 
				
			|||||||
 | 
					build/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
__pycache__
 | 
					__pycache__
 | 
				
			||||||
.pytest_cache
 | 
					.pytest_cache
 | 
				
			||||||
 | 
					
 | 
				
			||||||
python/build/
 | 
					python/build/
 | 
				
			||||||
python/triton.egg-info/
 | 
					python/triton.egg-info/
 | 
				
			||||||
 | 
					python/triton/_C/libtriton.pyd
 | 
				
			||||||
python/triton/_C/libtriton.so
 | 
					python/triton/_C/libtriton.so
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										3
									
								
								.gitmodules
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								.gitmodules
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@@ -0,0 +1,3 @@
 | 
				
			|||||||
 | 
					[submodule "deps/dlfcn-win32"]
 | 
				
			||||||
 | 
						path = deps/dlfcn-win32
 | 
				
			||||||
 | 
						url = https://github.com/dlfcn-win32/dlfcn-win32.git
 | 
				
			||||||
@@ -1,6 +1,8 @@
 | 
				
			|||||||
cmake_minimum_required(VERSION 3.6)
 | 
					cmake_minimum_required(VERSION 3.6)
 | 
				
			||||||
include(ExternalProject)
 | 
					include(ExternalProject)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					set(CMAKE_CXX_STANDARD 17)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if(NOT TRITON_LLVM_BUILD_DIR)
 | 
					if(NOT TRITON_LLVM_BUILD_DIR)
 | 
				
			||||||
    set(TRITON_LLVM_BUILD_DIR ${CMAKE_BINARY_DIR})
 | 
					    set(TRITON_LLVM_BUILD_DIR ${CMAKE_BINARY_DIR})
 | 
				
			||||||
endif()
 | 
					endif()
 | 
				
			||||||
@@ -8,7 +10,9 @@ endif()
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
project(triton)
 | 
					project(triton)
 | 
				
			||||||
include(CTest)
 | 
					include(CTest)
 | 
				
			||||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
 | 
					if(NOT WIN32)
 | 
				
			||||||
 | 
					  list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
 | 
				
			||||||
 | 
					endif()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Options
 | 
					# Options
 | 
				
			||||||
option(BUILD_TUTORIALS "Build C++ Triton tutorials" ON)
 | 
					option(BUILD_TUTORIALS "Build C++ Triton tutorials" ON)
 | 
				
			||||||
@@ -20,10 +24,19 @@ if(NOT CMAKE_BUILD_TYPE)
 | 
				
			|||||||
  set(CMAKE_BUILD_TYPE "Release")
 | 
					  set(CMAKE_BUILD_TYPE "Release")
 | 
				
			||||||
endif()
 | 
					endif()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
find_library(TERMINFO_LIBRARY tinfo)
 | 
					if(NOT WIN32)
 | 
				
			||||||
 | 
					    find_library(TERMINFO_LIBRARY tinfo)
 | 
				
			||||||
 | 
					endif()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Compiler flags
 | 
					# Compiler flags
 | 
				
			||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
 | 
					include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if(WIN32)
 | 
				
			||||||
 | 
					    SET(BUILD_SHARED_LIBS OFF)
 | 
				
			||||||
 | 
					    include_directories(${CMAKE_CURRENT_SOURCE_DIR}/deps/dlfcn-win32/src)
 | 
				
			||||||
 | 
					    add_subdirectory(deps/dlfcn-win32/src ${CMAKE_BINARY_DIR}/dlfcn-win32)
 | 
				
			||||||
 | 
					endif()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS  -std=gnu++17")
 | 
					set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS  -std=gnu++17")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -31,7 +44,20 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS  -std=gnu++17")
 | 
				
			|||||||
# LLVM
 | 
					# LLVM
 | 
				
			||||||
##########
 | 
					##########
 | 
				
			||||||
if("${LLVM_LIBRARY_DIR}" STREQUAL "")
 | 
					if("${LLVM_LIBRARY_DIR}" STREQUAL "")
 | 
				
			||||||
    find_package(LLVM 11 REQUIRED COMPONENTS "nvptx;amdgpu")
 | 
					    if(WIN32)
 | 
				
			||||||
 | 
					      find_package(LLVM 13 REQUIRED COMPONENTS nvptx amdgpu)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      include_directories(${LLVM_INCLUDE_DIRS})
 | 
				
			||||||
 | 
					      separate_arguments(LLVM_DEFINITIONS_LIST NATIVE_COMMAND ${LLVM_DEFINITIONS})
 | 
				
			||||||
 | 
					      add_definitions(${LLVM_DEFINITIONS_LIST})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      llvm_map_components_to_libnames(LLVM_LIBRARIES support core
 | 
				
			||||||
 | 
					        NVPTXInfo nvptxcodegen
 | 
				
			||||||
 | 
					        AMDGPUInfo AMDGPUcodegen
 | 
				
			||||||
 | 
					      )
 | 
				
			||||||
 | 
					    else()
 | 
				
			||||||
 | 
					      find_package(LLVM 11 REQUIRED COMPONENTS "nvptx;amdgpu")
 | 
				
			||||||
 | 
					    endif()
 | 
				
			||||||
    message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}")
 | 
					    message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}")
 | 
				
			||||||
    if(APPLE)
 | 
					    if(APPLE)
 | 
				
			||||||
      set(CMAKE_OSX_DEPLOYMENT_TARGET "10.14")
 | 
					      set(CMAKE_OSX_DEPLOYMENT_TARGET "10.14")
 | 
				
			||||||
@@ -108,12 +134,25 @@ endif()
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
# Triton
 | 
					# Triton
 | 
				
			||||||
file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
 | 
					file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
 | 
				
			||||||
add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
 | 
					if (WIN32 AND BUILD_PYTHON_MODULE)
 | 
				
			||||||
 | 
					    find_package(Python3 REQUIRED COMPONENTS Development)
 | 
				
			||||||
 | 
					    Python3_add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
 | 
				
			||||||
 | 
					    set_target_properties(triton PROPERTIES SUFFIX ".pyd")
 | 
				
			||||||
 | 
					    set_target_properties(triton PROPERTIES PREFIX "lib")
 | 
				
			||||||
 | 
					else()
 | 
				
			||||||
 | 
					    add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
 | 
				
			||||||
 | 
					endif()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
 | 
					target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
 | 
				
			||||||
target_link_libraries(triton ${LLVM_LIBRARIES} z ${TERMINFO_LIBRARY})
 | 
					
 | 
				
			||||||
 | 
					if(WIN32)
 | 
				
			||||||
 | 
					    target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} dl) # dl is from dlfcn-win32
 | 
				
			||||||
 | 
					else()
 | 
				
			||||||
 | 
					    target_link_libraries(triton ${LLVM_LIBRARIES} z ${TERMINFO_LIBRARY})
 | 
				
			||||||
 | 
					endif()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if(BUILD_PYTHON_MODULE)
 | 
					if(BUILD_PYTHON_MODULE AND NOT WIN32)
 | 
				
			||||||
    set(CMAKE_SHARED_LIBRARY_SUFFIX ".so")
 | 
					    set(CMAKE_SHARED_LIBRARY_SUFFIX ".so")
 | 
				
			||||||
    # Check if the platform is MacOS
 | 
					    # Check if the platform is MacOS
 | 
				
			||||||
    if(APPLE)
 | 
					    if(APPLE)
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										1
									
								
								deps/dlfcn-win32
									
									
									
									
										vendored
									
									
										Submodule
									
								
							
							
								
								
								
								
								
							
						
						
									
										1
									
								
								deps/dlfcn-win32
									
									
									
									
										vendored
									
									
										Submodule
									
								
							 Submodule deps/dlfcn-win32 added at 522c301ec3
									
								
							@@ -13,6 +13,14 @@ namespace tools
 | 
				
			|||||||
{
 | 
					{
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#ifdef _WIN32
 | 
				
			||||||
 | 
					#define popen _popen
 | 
				
			||||||
 | 
					#define pclose _pclose
 | 
				
			||||||
 | 
					#endif
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#ifndef WEXITSTATUS
 | 
				
			||||||
 | 
					#define WEXITSTATUS(stat_val) ((unsigned)(stat_val) & 255)
 | 
				
			||||||
 | 
					#endif
 | 
				
			||||||
 | 
					
 | 
				
			||||||
int exec(const std::string& cmd, std::string& result) {
 | 
					int exec(const std::string& cmd, std::string& result) {
 | 
				
			||||||
  char buffer[128];
 | 
					  char buffer[128];
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -33,19 +33,10 @@ namespace tools
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    inline std::string getenv(const char * name)
 | 
					    inline std::string getenv(const char * name)
 | 
				
			||||||
    {
 | 
					    {
 | 
				
			||||||
        #ifdef _MSC_VER
 | 
					        const char * cstr = std::getenv(name);
 | 
				
			||||||
            char* cache_path = 0;
 | 
					 | 
				
			||||||
            std::size_t sz = 0;
 | 
					 | 
				
			||||||
            _dupenv_s(&cache_path, &sz, name);
 | 
					 | 
				
			||||||
        #else
 | 
					 | 
				
			||||||
            const char * cstr = std::getenv(name);
 | 
					 | 
				
			||||||
        #endif
 | 
					 | 
				
			||||||
        if(!cstr)
 | 
					        if(!cstr)
 | 
				
			||||||
            return "";
 | 
					            return "";
 | 
				
			||||||
        std::string result(cstr);
 | 
					        std::string result(cstr);
 | 
				
			||||||
        #ifdef _MSC_VER
 | 
					 | 
				
			||||||
            free(cache_path);
 | 
					 | 
				
			||||||
        #endif
 | 
					 | 
				
			||||||
        return result;
 | 
					        return result;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -449,18 +449,18 @@ std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp16x4(Value *in0
 | 
				
			|||||||
  "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t"
 | 
					  "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t"
 | 
				
			||||||
  "}", "=r,=r,r", false);
 | 
					  "}", "=r,=r,r", false);
 | 
				
			||||||
  Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4));
 | 
					  Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4));
 | 
				
			||||||
  packed_in = insert_elt(packed_in, in0, (int)0);
 | 
					  packed_in = insert_elt(packed_in, in0, (uint64_t)0);
 | 
				
			||||||
  packed_in = insert_elt(packed_in, in1, (int)1);
 | 
					  packed_in = insert_elt(packed_in, in1, (uint64_t)1);
 | 
				
			||||||
  packed_in = insert_elt(packed_in, in2, (int)2);
 | 
					  packed_in = insert_elt(packed_in, in2, (uint64_t)2);
 | 
				
			||||||
  packed_in = insert_elt(packed_in, in3, (int)3);
 | 
					  packed_in = insert_elt(packed_in, in3, (uint64_t)3);
 | 
				
			||||||
  Value *in = bit_cast(packed_in, i32_ty);
 | 
					  Value *in = bit_cast(packed_in, i32_ty);
 | 
				
			||||||
  Value *ret = call(ptx, {in});
 | 
					  Value *ret = call(ptx, {in});
 | 
				
			||||||
  Value *packed_ret0 = extract_val(ret, {0});
 | 
					  Value *packed_ret0 = extract_val(ret, {0});
 | 
				
			||||||
  Value *packed_ret1 = extract_val(ret, {1});
 | 
					  Value *packed_ret1 = extract_val(ret, {1});
 | 
				
			||||||
  Value *ret0 = extract_elt(packed_ret0, (int)0);
 | 
					  Value *ret0 = extract_elt(packed_ret0, (uint64_t)0);
 | 
				
			||||||
  Value *ret1 = extract_elt(packed_ret0, (int)1);
 | 
					  Value *ret1 = extract_elt(packed_ret0, (uint64_t)1);
 | 
				
			||||||
  Value *ret2 = extract_elt(packed_ret1, (int)0);
 | 
					  Value *ret2 = extract_elt(packed_ret1, (uint64_t)0);
 | 
				
			||||||
  Value *ret3 = extract_elt(packed_ret1, (int)1);
 | 
					  Value *ret3 = extract_elt(packed_ret1, (uint64_t)1);
 | 
				
			||||||
  return std::make_tuple(ret0, ret1, ret2, ret3);
 | 
					  return std::make_tuple(ret0, ret1, ret2, ret3);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -717,11 +717,11 @@ void generator::visit_load_inst(ir::load_inst* x){
 | 
				
			|||||||
    // ---
 | 
					    // ---
 | 
				
			||||||
    // finally call inline ASM
 | 
					    // finally call inline ASM
 | 
				
			||||||
    // ---
 | 
					    // ---
 | 
				
			||||||
    InlineAsm *_asm = InlineAsm::get(asm_ty, asm_oss.str(), asm_cstrt, true);
 | 
					    InlineAsm *inlineAsm = InlineAsm::get(asm_ty, asm_oss.str(), asm_cstrt, true);
 | 
				
			||||||
    std::vector<Value*> args = {pred, ptr};
 | 
					    std::vector<Value*> args = {pred, ptr};
 | 
				
			||||||
    for(Value *v: others)
 | 
					    for(Value *v: others)
 | 
				
			||||||
        args.push_back(v);
 | 
					        args.push_back(v);
 | 
				
			||||||
    Value *_ret = call(_asm, args);
 | 
					    Value *_ret = call(inlineAsm, args);
 | 
				
			||||||
    // ---
 | 
					    // ---
 | 
				
			||||||
    // extract and store return values
 | 
					    // extract and store return values
 | 
				
			||||||
    // ---
 | 
					    // ---
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -91,9 +91,13 @@ void* dispatch::fname ## _;
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
bool dispatch::cuinit(){
 | 
					bool dispatch::cuinit(){
 | 
				
			||||||
  if(cuda_==nullptr){
 | 
					  if(cuda_==nullptr){
 | 
				
			||||||
 | 
					    #ifdef _WIN32
 | 
				
			||||||
 | 
					    cuda_ = dlopen("cudart64_110.dll", RTLD_LAZY);
 | 
				
			||||||
 | 
					    #else
 | 
				
			||||||
    cuda_ = dlopen("libcuda.so", RTLD_LAZY);
 | 
					    cuda_ = dlopen("libcuda.so", RTLD_LAZY);
 | 
				
			||||||
    if(!cuda_)
 | 
					    if(!cuda_)
 | 
				
			||||||
      cuda_ = dlopen("libcuda.so.1", RTLD_LAZY);
 | 
					      cuda_ = dlopen("libcuda.so.1", RTLD_LAZY);
 | 
				
			||||||
 | 
					    #endif
 | 
				
			||||||
    if(!cuda_)
 | 
					    if(!cuda_)
 | 
				
			||||||
      throw std::runtime_error("Could not find `libcuda.so`. Make sure it is in your LD_LIBRARY_PATH.");
 | 
					      throw std::runtime_error("Could not find `libcuda.so`. Make sure it is in your LD_LIBRARY_PATH.");
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
@@ -176,8 +180,13 @@ CUDA_DEFINE1(CUresult, cuEventDestroy_v2, CUevent)
 | 
				
			|||||||
 * NVML
 | 
					 * NVML
 | 
				
			||||||
 * ------------------- */
 | 
					 * ------------------- */
 | 
				
			||||||
bool dispatch::nvmlinit(){
 | 
					bool dispatch::nvmlinit(){
 | 
				
			||||||
 | 
					  #ifdef _WIN32
 | 
				
			||||||
 | 
					  if(nvml_==nullptr)
 | 
				
			||||||
 | 
					    nvml_ = dlopen("nvml.dll", RTLD_LAZY);
 | 
				
			||||||
 | 
					  #else
 | 
				
			||||||
  if(nvml_==nullptr)
 | 
					  if(nvml_==nullptr)
 | 
				
			||||||
    nvml_ = dlopen("libnvidia-ml.so", RTLD_LAZY);
 | 
					    nvml_ = dlopen("libnvidia-ml.so", RTLD_LAZY);
 | 
				
			||||||
 | 
					  #endif
 | 
				
			||||||
  nvmlReturn_t (*fptr)();
 | 
					  nvmlReturn_t (*fptr)();
 | 
				
			||||||
  nvmlInit_v2_ = dlsym(nvml_, "nvmlInit_v2");
 | 
					  nvmlInit_v2_ = dlsym(nvml_, "nvmlInit_v2");
 | 
				
			||||||
  *reinterpret_cast<void **>(&fptr) = nvmlInit_v2_;
 | 
					  *reinterpret_cast<void **>(&fptr) = nvmlInit_v2_;
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -20,7 +20,9 @@
 | 
				
			|||||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 | 
					* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 | 
				
			||||||
*/
 | 
					*/
 | 
				
			||||||
#include <fstream>
 | 
					#include <fstream>
 | 
				
			||||||
#include <unistd.h>
 | 
					#if __has_include(<unistd.h>)
 | 
				
			||||||
 | 
					    #include <unistd.h>
 | 
				
			||||||
 | 
					#endif
 | 
				
			||||||
#include <memory>
 | 
					#include <memory>
 | 
				
			||||||
#include <regex>
 | 
					#include <regex>
 | 
				
			||||||
#include "triton/driver/llvm.h"
 | 
					#include "triton/driver/llvm.h"
 | 
				
			||||||
@@ -185,8 +187,10 @@ std::string ptx_to_cubin(const std::string& ptx, int cc) {
 | 
				
			|||||||
  // compile ptx with ptxas
 | 
					  // compile ptx with ptxas
 | 
				
			||||||
  char _fsrc[L_tmpnam];
 | 
					  char _fsrc[L_tmpnam];
 | 
				
			||||||
  char _flog[L_tmpnam];
 | 
					  char _flog[L_tmpnam];
 | 
				
			||||||
  std::string fsrc = std::tmpnam(_fsrc);
 | 
					  std::tmpnam(_fsrc);
 | 
				
			||||||
  std::string flog = std::tmpnam(_flog);
 | 
					  std::tmpnam(_flog);
 | 
				
			||||||
 | 
					  std::string fsrc = _fsrc;
 | 
				
			||||||
 | 
					  std::string flog = _flog;
 | 
				
			||||||
  std::string fbin = fsrc + ".o";
 | 
					  std::string fbin = fsrc + ".o";
 | 
				
			||||||
  const char* _fbin = fbin.c_str();
 | 
					  const char* _fbin = fbin.c_str();
 | 
				
			||||||
  std::ofstream ofs(fsrc);
 | 
					  std::ofstream ofs(fsrc);
 | 
				
			||||||
@@ -367,8 +371,8 @@ hipModule_t amdgpu_to_hipmodule(const std::string& path) {
 | 
				
			|||||||
  hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes, hipJitOptionErrorLogBuffer,
 | 
					  hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes, hipJitOptionErrorLogBuffer,
 | 
				
			||||||
                            hipJitOptionInfoLogBufferSizeBytes, hipJitOptionInfoLogBuffer,
 | 
					                            hipJitOptionInfoLogBufferSizeBytes, hipJitOptionInfoLogBuffer,
 | 
				
			||||||
                            hipJitOptionLogVerbose};
 | 
					                            hipJitOptionLogVerbose};
 | 
				
			||||||
  unsigned int errbufsize = 8192;
 | 
					  const unsigned int errbufsize = 8192;
 | 
				
			||||||
  unsigned int logbufsize = 8192;
 | 
					  const unsigned int logbufsize = 8192;
 | 
				
			||||||
  char _err[errbufsize];
 | 
					  char _err[errbufsize];
 | 
				
			||||||
  char _log[logbufsize];
 | 
					  char _log[logbufsize];
 | 
				
			||||||
  void* optval[] = {(void*)(uintptr_t)errbufsize,
 | 
					  void* optval[] = {(void*)(uintptr_t)errbufsize,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -23,6 +23,8 @@ def get_llvm():
 | 
				
			|||||||
    paths = [p for p in paths if p is not None]
 | 
					    paths = [p for p in paths if p is not None]
 | 
				
			||||||
    if paths:
 | 
					    if paths:
 | 
				
			||||||
        return '', ''
 | 
					        return '', ''
 | 
				
			||||||
 | 
					    if platform.system() == "Windows":
 | 
				
			||||||
 | 
					        return '', ''
 | 
				
			||||||
    # download if nothing is installed
 | 
					    # download if nothing is installed
 | 
				
			||||||
    name = 'clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04'
 | 
					    name = 'clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04'
 | 
				
			||||||
    dir = '/tmp'
 | 
					    dir = '/tmp'
 | 
				
			||||||
@@ -104,7 +106,7 @@ class CMakeBuild(build_ext):
 | 
				
			|||||||
        build_args = ["--config", cfg]
 | 
					        build_args = ["--config", cfg]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if platform.system() == "Windows":
 | 
					        if platform.system() == "Windows":
 | 
				
			||||||
            cmake_args += ["-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}".format(cfg.upper(), extdir)]
 | 
					            cmake_args += ["-DCMAKE_RUNTIME_OUTPUT_DIRECTORY_{}={}".format(cfg.upper(), extdir)]
 | 
				
			||||||
            if sys.maxsize > 2**32:
 | 
					            if sys.maxsize > 2**32:
 | 
				
			||||||
                cmake_args += ["-A", "x64"]
 | 
					                cmake_args += ["-A", "x64"]
 | 
				
			||||||
            build_args += ["--", "/m"]
 | 
					            build_args += ["--", "/m"]
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -15,6 +15,7 @@
 | 
				
			|||||||
#include <pybind11/stl.h>
 | 
					#include <pybind11/stl.h>
 | 
				
			||||||
#include "Python.h"
 | 
					#include "Python.h"
 | 
				
			||||||
#include <regex>
 | 
					#include <regex>
 | 
				
			||||||
 | 
					#include <sstream>
 | 
				
			||||||
#include <string>
 | 
					#include <string>
 | 
				
			||||||
#include "llvm/IR/Module.h"
 | 
					#include "llvm/IR/Module.h"
 | 
				
			||||||
#include "llvm/IR/LegacyPassManager.h"
 | 
					#include "llvm/IR/LegacyPassManager.h"
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -25,13 +25,13 @@ def get_p2p_matrix():
 | 
				
			|||||||
def get_p2p_devices():
 | 
					def get_p2p_devices():
 | 
				
			||||||
    matrix = get_p2p_matrix()
 | 
					    matrix = get_p2p_matrix()
 | 
				
			||||||
    idx = np.where(matrix == "OK")
 | 
					    idx = np.where(matrix == "OK")
 | 
				
			||||||
    return f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"
 | 
					    return [f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"] if len(idx[0]) > 0 else []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_non_p2p_devices():
 | 
					def get_non_p2p_devices():
 | 
				
			||||||
    matrix = get_p2p_matrix()
 | 
					    matrix = get_p2p_matrix()
 | 
				
			||||||
    idx = np.where(matrix == "NS")
 | 
					    idx = np.where(matrix == "NS")
 | 
				
			||||||
    return f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"
 | 
					    return [f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"] if len(idx[0]) > 0 else []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
p2p_devices = get_p2p_devices()
 | 
					p2p_devices = get_p2p_devices()
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -358,9 +358,6 @@ class CodeGenerator(ast.NodeVisitor):
 | 
				
			|||||||
        for stmt in node.orelse:
 | 
					        for stmt in node.orelse:
 | 
				
			||||||
            ast.NodeVisitor.generic_visit(self, stmt)
 | 
					            ast.NodeVisitor.generic_visit(self, stmt)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def visit_Str(self, node):
 | 
					 | 
				
			||||||
        return ast.literal_eval(node)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def visit_Subscript(self, node):
 | 
					    def visit_Subscript(self, node):
 | 
				
			||||||
        assert node.ctx.__class__.__name__ == "Load"
 | 
					        assert node.ctx.__class__.__name__ == "Load"
 | 
				
			||||||
        lhs = self.visit(node.value)
 | 
					        lhs = self.visit(node.value)
 | 
				
			||||||
@@ -441,9 +438,6 @@ class CodeGenerator(ast.NodeVisitor):
 | 
				
			|||||||
    def visit_Index(self, node):
 | 
					    def visit_Index(self, node):
 | 
				
			||||||
        return self.visit(node.value)
 | 
					        return self.visit(node.value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def visit_NameConstant(self, node):
 | 
					 | 
				
			||||||
        return node.value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def visit_keyword(self, node):
 | 
					    def visit_keyword(self, node):
 | 
				
			||||||
        return {node.arg: self.visit(node.value)}
 | 
					        return {node.arg: self.visit(node.value)}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -460,10 +454,23 @@ class CodeGenerator(ast.NodeVisitor):
 | 
				
			|||||||
        if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
 | 
					        if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
 | 
				
			||||||
                sys.modules[fn.__module__] is triton.language.core:
 | 
					                sys.modules[fn.__module__] is triton.language.core:
 | 
				
			||||||
            return fn(*args, _builder=self.builder, **kws)
 | 
					            return fn(*args, _builder=self.builder, **kws)
 | 
				
			||||||
 | 
					        if fn in self.builtins.values():
 | 
				
			||||||
 | 
					            args = [arg.value if isinstance(arg, triton.language.constexpr) else arg
 | 
				
			||||||
 | 
					                    for arg in args]
 | 
				
			||||||
        return fn(*args, **kws)
 | 
					        return fn(*args, **kws)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def visit_Num(self, node):
 | 
					    def visit_Constant(self, node):
 | 
				
			||||||
        return triton.language.constexpr(node.n)
 | 
					        return triton.language.constexpr(node.value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if sys.version_info < (3, 8):
 | 
				
			||||||
 | 
					        def visit_NameConstant(self, node):
 | 
				
			||||||
 | 
					            return triton.language.constexpr(node.value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def visit_Num(self, node):
 | 
				
			||||||
 | 
					            return triton.language.constexpr(node.n)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def visit_Str(self, node):
 | 
				
			||||||
 | 
					            return triton.language.constexpr(ast.literal_eval(node))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def visit_Attribute(self, node):
 | 
					    def visit_Attribute(self, node):
 | 
				
			||||||
        lhs = self.visit(node.value)
 | 
					        lhs = self.visit(node.value)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -130,6 +130,94 @@ float64 = dtype(ir.type.get_fp64)
 | 
				
			|||||||
# pointer types
 | 
					# pointer types
 | 
				
			||||||
pi32_t = pointer_dtype(int32)
 | 
					pi32_t = pointer_dtype(int32)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# -----------------------
 | 
				
			||||||
 | 
					# constexpr
 | 
				
			||||||
 | 
					# -----------------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class constexpr:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    This class is used to store a value that is known at compile-time.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, value):
 | 
				
			||||||
 | 
					        if isinstance(value, constexpr):
 | 
				
			||||||
 | 
					            self.value = value.value
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.value = value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __repr__(self) -> str:
 | 
				
			||||||
 | 
					        return f"constexpr[{self.value}]"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #
 | 
				
			||||||
 | 
					    def __add__(self, other):
 | 
				
			||||||
 | 
					        return self.value + other.value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __radd__(self, other):
 | 
				
			||||||
 | 
					        return other.value + self.value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __sub__(self, other):
 | 
				
			||||||
 | 
					        return self.value - other.value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __rsub__(self, other):
 | 
				
			||||||
 | 
					        return other.value - self.value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __mul__(self, other):
 | 
				
			||||||
 | 
					        return self.value * other.value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __rmul__(self, other):
 | 
				
			||||||
 | 
					        return other.value * self.value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __truediv__(self, other):
 | 
				
			||||||
 | 
					        return self.value / other.value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __rtruediv__(self, other):
 | 
				
			||||||
 | 
					        return other.value / self.value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __floordiv__(self, other):
 | 
				
			||||||
 | 
					        return self.value // other.value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __rfloordiv__(self, other):
 | 
				
			||||||
 | 
					        return other.value // self.value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __gt__(self, other):
 | 
				
			||||||
 | 
					        return self.value > other.value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __rgt__(self, other):
 | 
				
			||||||
 | 
					        return other.value > self.value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __ge__(self, other):
 | 
				
			||||||
 | 
					        return self.value >= other.value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __rge__(self, other):
 | 
				
			||||||
 | 
					        return other.value >= self.value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __lt__(self, other):
 | 
				
			||||||
 | 
					        return self.value < other.value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __rlt__(self, other):
 | 
				
			||||||
 | 
					        return other.value < self.value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __le__(self, other):
 | 
				
			||||||
 | 
					        return self.value <= other.value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __rle__(self, other):
 | 
				
			||||||
 | 
					        return other.value <= self.value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __eq__(self, other):
 | 
				
			||||||
 | 
					        return self.value == other.value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __ne__(self, other):
 | 
				
			||||||
 | 
					        return self.value != other.value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __bool__(self):
 | 
				
			||||||
 | 
					        return bool(self.value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __call__(self, *args, **kwds):
 | 
				
			||||||
 | 
					        return self.value(*args, **kwds)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class block:
 | 
					class block:
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
@@ -296,7 +384,7 @@ class block:
 | 
				
			|||||||
        dst_shape = []
 | 
					        dst_shape = []
 | 
				
			||||||
        curr = 0
 | 
					        curr = 0
 | 
				
			||||||
        for sl in slices:
 | 
					        for sl in slices:
 | 
				
			||||||
            if sl is None:
 | 
					            if isinstance(sl, constexpr) and sl.value is None:
 | 
				
			||||||
                dst_shape.append(1)
 | 
					                dst_shape.append(1)
 | 
				
			||||||
            elif sl == slice(None, None, None):
 | 
					            elif sl == slice(None, None, None):
 | 
				
			||||||
                dst_shape.append(src_shape[curr].value)
 | 
					                dst_shape.append(src_shape[curr].value)
 | 
				
			||||||
@@ -312,93 +400,6 @@ class block:
 | 
				
			|||||||
        return frontend.cast(self, dtype, _builder)
 | 
					        return frontend.cast(self, dtype, _builder)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# -----------------------
 | 
					 | 
				
			||||||
# constexpr
 | 
					 | 
				
			||||||
# -----------------------
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class constexpr:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    This class is used to store a value that is known at compile-time.
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self, value):
 | 
					 | 
				
			||||||
        if isinstance(value, constexpr):
 | 
					 | 
				
			||||||
            self.value = value.value
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            self.value = value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __repr__(self) -> str:
 | 
					 | 
				
			||||||
        return f"constexpr[{self.value}]"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    #
 | 
					 | 
				
			||||||
    def __add__(self, other):
 | 
					 | 
				
			||||||
        return self.value + other.value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __radd__(self, other):
 | 
					 | 
				
			||||||
        return other.value + self.value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __sub__(self, other):
 | 
					 | 
				
			||||||
        return self.value - other.value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __rsub__(self, other):
 | 
					 | 
				
			||||||
        return other.value - self.value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __mul__(self, other):
 | 
					 | 
				
			||||||
        return self.value * other.value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __rmul__(self, other):
 | 
					 | 
				
			||||||
        return other.value * self.value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __truediv__(self, other):
 | 
					 | 
				
			||||||
        return self.value / other.value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __rtruediv__(self, other):
 | 
					 | 
				
			||||||
        return other.value / self.value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __floordiv__(self, other):
 | 
					 | 
				
			||||||
        return self.value // other.value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __rfloordiv__(self, other):
 | 
					 | 
				
			||||||
        return other.value // self.value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    #
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __gt__(self, other):
 | 
					 | 
				
			||||||
        return self.value > other.value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __rgt__(self, other):
 | 
					 | 
				
			||||||
        return other.value > self.value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __ge__(self, other):
 | 
					 | 
				
			||||||
        return self.value >= other.value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __rge__(self, other):
 | 
					 | 
				
			||||||
        return other.value >= self.value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __lt__(self, other):
 | 
					 | 
				
			||||||
        return self.value < other.value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __rlt__(self, other):
 | 
					 | 
				
			||||||
        return other.value < self.value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __le__(self, other):
 | 
					 | 
				
			||||||
        return self.value <= other.value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __rle__(self, other):
 | 
					 | 
				
			||||||
        return other.value <= self.value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __eq__(self, other):
 | 
					 | 
				
			||||||
        return self.value == other.value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __ne__(self, other):
 | 
					 | 
				
			||||||
        return self.value != other.value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __bool__(self):
 | 
					 | 
				
			||||||
        return bool(self.value)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __call__(self, *args, **kwds):
 | 
					 | 
				
			||||||
        return self.value(*args, **kwds)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# -----------------------
 | 
					# -----------------------
 | 
				
			||||||
# SPMD Programming Model
 | 
					# SPMD Programming Model
 | 
				
			||||||
# -----------------------
 | 
					# -----------------------
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user