[ALL] Merge master (#447)
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,6 +1,9 @@
|
|||||||
|
build/
|
||||||
|
|
||||||
__pycache__
|
__pycache__
|
||||||
.pytest_cache
|
.pytest_cache
|
||||||
|
|
||||||
python/build/
|
python/build/
|
||||||
python/triton.egg-info/
|
python/triton.egg-info/
|
||||||
|
python/triton/_C/libtriton.pyd
|
||||||
python/triton/_C/libtriton.so
|
python/triton/_C/libtriton.so
|
||||||
|
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
[submodule "deps/dlfcn-win32"]
|
||||||
|
path = deps/dlfcn-win32
|
||||||
|
url = https://github.com/dlfcn-win32/dlfcn-win32.git
|
@@ -1,6 +1,8 @@
|
|||||||
cmake_minimum_required(VERSION 3.6)
|
cmake_minimum_required(VERSION 3.6)
|
||||||
include(ExternalProject)
|
include(ExternalProject)
|
||||||
|
|
||||||
|
set(CMAKE_CXX_STANDARD 17)
|
||||||
|
|
||||||
if(NOT TRITON_LLVM_BUILD_DIR)
|
if(NOT TRITON_LLVM_BUILD_DIR)
|
||||||
set(TRITON_LLVM_BUILD_DIR ${CMAKE_BINARY_DIR})
|
set(TRITON_LLVM_BUILD_DIR ${CMAKE_BINARY_DIR})
|
||||||
endif()
|
endif()
|
||||||
@@ -8,7 +10,9 @@ endif()
|
|||||||
|
|
||||||
project(triton)
|
project(triton)
|
||||||
include(CTest)
|
include(CTest)
|
||||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
|
if(NOT WIN32)
|
||||||
|
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
|
||||||
|
endif()
|
||||||
|
|
||||||
# Options
|
# Options
|
||||||
option(BUILD_TUTORIALS "Build C++ Triton tutorials" ON)
|
option(BUILD_TUTORIALS "Build C++ Triton tutorials" ON)
|
||||||
@@ -20,10 +24,19 @@ if(NOT CMAKE_BUILD_TYPE)
|
|||||||
set(CMAKE_BUILD_TYPE "Release")
|
set(CMAKE_BUILD_TYPE "Release")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
find_library(TERMINFO_LIBRARY tinfo)
|
if(NOT WIN32)
|
||||||
|
find_library(TERMINFO_LIBRARY tinfo)
|
||||||
|
endif()
|
||||||
|
|
||||||
# Compiler flags
|
# Compiler flags
|
||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
|
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
|
||||||
|
|
||||||
|
if(WIN32)
|
||||||
|
SET(BUILD_SHARED_LIBS OFF)
|
||||||
|
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/deps/dlfcn-win32/src)
|
||||||
|
add_subdirectory(deps/dlfcn-win32/src ${CMAKE_BINARY_DIR}/dlfcn-win32)
|
||||||
|
endif()
|
||||||
|
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17")
|
||||||
|
|
||||||
|
|
||||||
@@ -31,7 +44,20 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17")
|
|||||||
# LLVM
|
# LLVM
|
||||||
##########
|
##########
|
||||||
if("${LLVM_LIBRARY_DIR}" STREQUAL "")
|
if("${LLVM_LIBRARY_DIR}" STREQUAL "")
|
||||||
find_package(LLVM 11 REQUIRED COMPONENTS "nvptx;amdgpu")
|
if(WIN32)
|
||||||
|
find_package(LLVM 13 REQUIRED COMPONENTS nvptx amdgpu)
|
||||||
|
|
||||||
|
include_directories(${LLVM_INCLUDE_DIRS})
|
||||||
|
separate_arguments(LLVM_DEFINITIONS_LIST NATIVE_COMMAND ${LLVM_DEFINITIONS})
|
||||||
|
add_definitions(${LLVM_DEFINITIONS_LIST})
|
||||||
|
|
||||||
|
llvm_map_components_to_libnames(LLVM_LIBRARIES support core
|
||||||
|
NVPTXInfo nvptxcodegen
|
||||||
|
AMDGPUInfo AMDGPUcodegen
|
||||||
|
)
|
||||||
|
else()
|
||||||
|
find_package(LLVM 11 REQUIRED COMPONENTS "nvptx;amdgpu")
|
||||||
|
endif()
|
||||||
message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}")
|
message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}")
|
||||||
if(APPLE)
|
if(APPLE)
|
||||||
set(CMAKE_OSX_DEPLOYMENT_TARGET "10.14")
|
set(CMAKE_OSX_DEPLOYMENT_TARGET "10.14")
|
||||||
@@ -108,12 +134,25 @@ endif()
|
|||||||
|
|
||||||
# Triton
|
# Triton
|
||||||
file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
|
file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
|
||||||
add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
|
if (WIN32 AND BUILD_PYTHON_MODULE)
|
||||||
|
find_package(Python3 REQUIRED COMPONENTS Development)
|
||||||
|
Python3_add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
|
||||||
|
set_target_properties(triton PROPERTIES SUFFIX ".pyd")
|
||||||
|
set_target_properties(triton PROPERTIES PREFIX "lib")
|
||||||
|
else()
|
||||||
|
add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
|
||||||
|
endif()
|
||||||
|
|
||||||
target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
|
target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
|
||||||
target_link_libraries(triton ${LLVM_LIBRARIES} z ${TERMINFO_LIBRARY})
|
|
||||||
|
if(WIN32)
|
||||||
|
target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} dl) # dl is from dlfcn-win32
|
||||||
|
else()
|
||||||
|
target_link_libraries(triton ${LLVM_LIBRARIES} z ${TERMINFO_LIBRARY})
|
||||||
|
endif()
|
||||||
|
|
||||||
|
|
||||||
if(BUILD_PYTHON_MODULE)
|
if(BUILD_PYTHON_MODULE AND NOT WIN32)
|
||||||
set(CMAKE_SHARED_LIBRARY_SUFFIX ".so")
|
set(CMAKE_SHARED_LIBRARY_SUFFIX ".so")
|
||||||
# Check if the platform is MacOS
|
# Check if the platform is MacOS
|
||||||
if(APPLE)
|
if(APPLE)
|
||||||
|
1
deps/dlfcn-win32
vendored
Submodule
1
deps/dlfcn-win32
vendored
Submodule
Submodule deps/dlfcn-win32 added at 522c301ec3
@@ -13,6 +13,14 @@ namespace tools
|
|||||||
{
|
{
|
||||||
|
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
#define popen _popen
|
||||||
|
#define pclose _pclose
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifndef WEXITSTATUS
|
||||||
|
#define WEXITSTATUS(stat_val) ((unsigned)(stat_val) & 255)
|
||||||
|
#endif
|
||||||
|
|
||||||
int exec(const std::string& cmd, std::string& result) {
|
int exec(const std::string& cmd, std::string& result) {
|
||||||
char buffer[128];
|
char buffer[128];
|
||||||
|
@@ -33,19 +33,10 @@ namespace tools
|
|||||||
|
|
||||||
inline std::string getenv(const char * name)
|
inline std::string getenv(const char * name)
|
||||||
{
|
{
|
||||||
#ifdef _MSC_VER
|
const char * cstr = std::getenv(name);
|
||||||
char* cache_path = 0;
|
|
||||||
std::size_t sz = 0;
|
|
||||||
_dupenv_s(&cache_path, &sz, name);
|
|
||||||
#else
|
|
||||||
const char * cstr = std::getenv(name);
|
|
||||||
#endif
|
|
||||||
if(!cstr)
|
if(!cstr)
|
||||||
return "";
|
return "";
|
||||||
std::string result(cstr);
|
std::string result(cstr);
|
||||||
#ifdef _MSC_VER
|
|
||||||
free(cache_path);
|
|
||||||
#endif
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -449,18 +449,18 @@ std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp16x4(Value *in0
|
|||||||
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t"
|
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t"
|
||||||
"}", "=r,=r,r", false);
|
"}", "=r,=r,r", false);
|
||||||
Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4));
|
Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4));
|
||||||
packed_in = insert_elt(packed_in, in0, (int)0);
|
packed_in = insert_elt(packed_in, in0, (uint64_t)0);
|
||||||
packed_in = insert_elt(packed_in, in1, (int)1);
|
packed_in = insert_elt(packed_in, in1, (uint64_t)1);
|
||||||
packed_in = insert_elt(packed_in, in2, (int)2);
|
packed_in = insert_elt(packed_in, in2, (uint64_t)2);
|
||||||
packed_in = insert_elt(packed_in, in3, (int)3);
|
packed_in = insert_elt(packed_in, in3, (uint64_t)3);
|
||||||
Value *in = bit_cast(packed_in, i32_ty);
|
Value *in = bit_cast(packed_in, i32_ty);
|
||||||
Value *ret = call(ptx, {in});
|
Value *ret = call(ptx, {in});
|
||||||
Value *packed_ret0 = extract_val(ret, {0});
|
Value *packed_ret0 = extract_val(ret, {0});
|
||||||
Value *packed_ret1 = extract_val(ret, {1});
|
Value *packed_ret1 = extract_val(ret, {1});
|
||||||
Value *ret0 = extract_elt(packed_ret0, (int)0);
|
Value *ret0 = extract_elt(packed_ret0, (uint64_t)0);
|
||||||
Value *ret1 = extract_elt(packed_ret0, (int)1);
|
Value *ret1 = extract_elt(packed_ret0, (uint64_t)1);
|
||||||
Value *ret2 = extract_elt(packed_ret1, (int)0);
|
Value *ret2 = extract_elt(packed_ret1, (uint64_t)0);
|
||||||
Value *ret3 = extract_elt(packed_ret1, (int)1);
|
Value *ret3 = extract_elt(packed_ret1, (uint64_t)1);
|
||||||
return std::make_tuple(ret0, ret1, ret2, ret3);
|
return std::make_tuple(ret0, ret1, ret2, ret3);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -717,11 +717,11 @@ void generator::visit_load_inst(ir::load_inst* x){
|
|||||||
// ---
|
// ---
|
||||||
// finally call inline ASM
|
// finally call inline ASM
|
||||||
// ---
|
// ---
|
||||||
InlineAsm *_asm = InlineAsm::get(asm_ty, asm_oss.str(), asm_cstrt, true);
|
InlineAsm *inlineAsm = InlineAsm::get(asm_ty, asm_oss.str(), asm_cstrt, true);
|
||||||
std::vector<Value*> args = {pred, ptr};
|
std::vector<Value*> args = {pred, ptr};
|
||||||
for(Value *v: others)
|
for(Value *v: others)
|
||||||
args.push_back(v);
|
args.push_back(v);
|
||||||
Value *_ret = call(_asm, args);
|
Value *_ret = call(inlineAsm, args);
|
||||||
// ---
|
// ---
|
||||||
// extract and store return values
|
// extract and store return values
|
||||||
// ---
|
// ---
|
||||||
|
@@ -91,9 +91,13 @@ void* dispatch::fname ## _;
|
|||||||
|
|
||||||
bool dispatch::cuinit(){
|
bool dispatch::cuinit(){
|
||||||
if(cuda_==nullptr){
|
if(cuda_==nullptr){
|
||||||
|
#ifdef _WIN32
|
||||||
|
cuda_ = dlopen("cudart64_110.dll", RTLD_LAZY);
|
||||||
|
#else
|
||||||
cuda_ = dlopen("libcuda.so", RTLD_LAZY);
|
cuda_ = dlopen("libcuda.so", RTLD_LAZY);
|
||||||
if(!cuda_)
|
if(!cuda_)
|
||||||
cuda_ = dlopen("libcuda.so.1", RTLD_LAZY);
|
cuda_ = dlopen("libcuda.so.1", RTLD_LAZY);
|
||||||
|
#endif
|
||||||
if(!cuda_)
|
if(!cuda_)
|
||||||
throw std::runtime_error("Could not find `libcuda.so`. Make sure it is in your LD_LIBRARY_PATH.");
|
throw std::runtime_error("Could not find `libcuda.so`. Make sure it is in your LD_LIBRARY_PATH.");
|
||||||
}
|
}
|
||||||
@@ -176,8 +180,13 @@ CUDA_DEFINE1(CUresult, cuEventDestroy_v2, CUevent)
|
|||||||
* NVML
|
* NVML
|
||||||
* ------------------- */
|
* ------------------- */
|
||||||
bool dispatch::nvmlinit(){
|
bool dispatch::nvmlinit(){
|
||||||
|
#ifdef _WIN32
|
||||||
|
if(nvml_==nullptr)
|
||||||
|
nvml_ = dlopen("nvml.dll", RTLD_LAZY);
|
||||||
|
#else
|
||||||
if(nvml_==nullptr)
|
if(nvml_==nullptr)
|
||||||
nvml_ = dlopen("libnvidia-ml.so", RTLD_LAZY);
|
nvml_ = dlopen("libnvidia-ml.so", RTLD_LAZY);
|
||||||
|
#endif
|
||||||
nvmlReturn_t (*fptr)();
|
nvmlReturn_t (*fptr)();
|
||||||
nvmlInit_v2_ = dlsym(nvml_, "nvmlInit_v2");
|
nvmlInit_v2_ = dlsym(nvml_, "nvmlInit_v2");
|
||||||
*reinterpret_cast<void **>(&fptr) = nvmlInit_v2_;
|
*reinterpret_cast<void **>(&fptr) = nvmlInit_v2_;
|
||||||
|
@@ -20,7 +20,9 @@
|
|||||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
*/
|
*/
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <unistd.h>
|
#if __has_include(<unistd.h>)
|
||||||
|
#include <unistd.h>
|
||||||
|
#endif
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <regex>
|
#include <regex>
|
||||||
#include "triton/driver/llvm.h"
|
#include "triton/driver/llvm.h"
|
||||||
@@ -185,8 +187,10 @@ std::string ptx_to_cubin(const std::string& ptx, int cc) {
|
|||||||
// compile ptx with ptxas
|
// compile ptx with ptxas
|
||||||
char _fsrc[L_tmpnam];
|
char _fsrc[L_tmpnam];
|
||||||
char _flog[L_tmpnam];
|
char _flog[L_tmpnam];
|
||||||
std::string fsrc = std::tmpnam(_fsrc);
|
std::tmpnam(_fsrc);
|
||||||
std::string flog = std::tmpnam(_flog);
|
std::tmpnam(_flog);
|
||||||
|
std::string fsrc = _fsrc;
|
||||||
|
std::string flog = _flog;
|
||||||
std::string fbin = fsrc + ".o";
|
std::string fbin = fsrc + ".o";
|
||||||
const char* _fbin = fbin.c_str();
|
const char* _fbin = fbin.c_str();
|
||||||
std::ofstream ofs(fsrc);
|
std::ofstream ofs(fsrc);
|
||||||
@@ -367,8 +371,8 @@ hipModule_t amdgpu_to_hipmodule(const std::string& path) {
|
|||||||
hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes, hipJitOptionErrorLogBuffer,
|
hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes, hipJitOptionErrorLogBuffer,
|
||||||
hipJitOptionInfoLogBufferSizeBytes, hipJitOptionInfoLogBuffer,
|
hipJitOptionInfoLogBufferSizeBytes, hipJitOptionInfoLogBuffer,
|
||||||
hipJitOptionLogVerbose};
|
hipJitOptionLogVerbose};
|
||||||
unsigned int errbufsize = 8192;
|
const unsigned int errbufsize = 8192;
|
||||||
unsigned int logbufsize = 8192;
|
const unsigned int logbufsize = 8192;
|
||||||
char _err[errbufsize];
|
char _err[errbufsize];
|
||||||
char _log[logbufsize];
|
char _log[logbufsize];
|
||||||
void* optval[] = {(void*)(uintptr_t)errbufsize,
|
void* optval[] = {(void*)(uintptr_t)errbufsize,
|
||||||
|
@@ -23,6 +23,8 @@ def get_llvm():
|
|||||||
paths = [p for p in paths if p is not None]
|
paths = [p for p in paths if p is not None]
|
||||||
if paths:
|
if paths:
|
||||||
return '', ''
|
return '', ''
|
||||||
|
if platform.system() == "Windows":
|
||||||
|
return '', ''
|
||||||
# download if nothing is installed
|
# download if nothing is installed
|
||||||
name = 'clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04'
|
name = 'clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04'
|
||||||
dir = '/tmp'
|
dir = '/tmp'
|
||||||
@@ -104,7 +106,7 @@ class CMakeBuild(build_ext):
|
|||||||
build_args = ["--config", cfg]
|
build_args = ["--config", cfg]
|
||||||
|
|
||||||
if platform.system() == "Windows":
|
if platform.system() == "Windows":
|
||||||
cmake_args += ["-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}".format(cfg.upper(), extdir)]
|
cmake_args += ["-DCMAKE_RUNTIME_OUTPUT_DIRECTORY_{}={}".format(cfg.upper(), extdir)]
|
||||||
if sys.maxsize > 2**32:
|
if sys.maxsize > 2**32:
|
||||||
cmake_args += ["-A", "x64"]
|
cmake_args += ["-A", "x64"]
|
||||||
build_args += ["--", "/m"]
|
build_args += ["--", "/m"]
|
||||||
|
@@ -15,6 +15,7 @@
|
|||||||
#include <pybind11/stl.h>
|
#include <pybind11/stl.h>
|
||||||
#include "Python.h"
|
#include "Python.h"
|
||||||
#include <regex>
|
#include <regex>
|
||||||
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "llvm/IR/Module.h"
|
#include "llvm/IR/Module.h"
|
||||||
#include "llvm/IR/LegacyPassManager.h"
|
#include "llvm/IR/LegacyPassManager.h"
|
||||||
|
@@ -25,13 +25,13 @@ def get_p2p_matrix():
|
|||||||
def get_p2p_devices():
|
def get_p2p_devices():
|
||||||
matrix = get_p2p_matrix()
|
matrix = get_p2p_matrix()
|
||||||
idx = np.where(matrix == "OK")
|
idx = np.where(matrix == "OK")
|
||||||
return f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"
|
return [f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"] if len(idx[0]) > 0 else []
|
||||||
|
|
||||||
|
|
||||||
def get_non_p2p_devices():
|
def get_non_p2p_devices():
|
||||||
matrix = get_p2p_matrix()
|
matrix = get_p2p_matrix()
|
||||||
idx = np.where(matrix == "NS")
|
idx = np.where(matrix == "NS")
|
||||||
return f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"
|
return [f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"] if len(idx[0]) > 0 else []
|
||||||
|
|
||||||
|
|
||||||
p2p_devices = get_p2p_devices()
|
p2p_devices = get_p2p_devices()
|
||||||
|
@@ -358,9 +358,6 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
for stmt in node.orelse:
|
for stmt in node.orelse:
|
||||||
ast.NodeVisitor.generic_visit(self, stmt)
|
ast.NodeVisitor.generic_visit(self, stmt)
|
||||||
|
|
||||||
def visit_Str(self, node):
|
|
||||||
return ast.literal_eval(node)
|
|
||||||
|
|
||||||
def visit_Subscript(self, node):
|
def visit_Subscript(self, node):
|
||||||
assert node.ctx.__class__.__name__ == "Load"
|
assert node.ctx.__class__.__name__ == "Load"
|
||||||
lhs = self.visit(node.value)
|
lhs = self.visit(node.value)
|
||||||
@@ -441,9 +438,6 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
def visit_Index(self, node):
|
def visit_Index(self, node):
|
||||||
return self.visit(node.value)
|
return self.visit(node.value)
|
||||||
|
|
||||||
def visit_NameConstant(self, node):
|
|
||||||
return node.value
|
|
||||||
|
|
||||||
def visit_keyword(self, node):
|
def visit_keyword(self, node):
|
||||||
return {node.arg: self.visit(node.value)}
|
return {node.arg: self.visit(node.value)}
|
||||||
|
|
||||||
@@ -460,10 +454,23 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
|
if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
|
||||||
sys.modules[fn.__module__] is triton.language.core:
|
sys.modules[fn.__module__] is triton.language.core:
|
||||||
return fn(*args, _builder=self.builder, **kws)
|
return fn(*args, _builder=self.builder, **kws)
|
||||||
|
if fn in self.builtins.values():
|
||||||
|
args = [arg.value if isinstance(arg, triton.language.constexpr) else arg
|
||||||
|
for arg in args]
|
||||||
return fn(*args, **kws)
|
return fn(*args, **kws)
|
||||||
|
|
||||||
def visit_Num(self, node):
|
def visit_Constant(self, node):
|
||||||
return triton.language.constexpr(node.n)
|
return triton.language.constexpr(node.value)
|
||||||
|
|
||||||
|
if sys.version_info < (3, 8):
|
||||||
|
def visit_NameConstant(self, node):
|
||||||
|
return triton.language.constexpr(node.value)
|
||||||
|
|
||||||
|
def visit_Num(self, node):
|
||||||
|
return triton.language.constexpr(node.n)
|
||||||
|
|
||||||
|
def visit_Str(self, node):
|
||||||
|
return triton.language.constexpr(ast.literal_eval(node))
|
||||||
|
|
||||||
def visit_Attribute(self, node):
|
def visit_Attribute(self, node):
|
||||||
lhs = self.visit(node.value)
|
lhs = self.visit(node.value)
|
||||||
|
@@ -130,6 +130,94 @@ float64 = dtype(ir.type.get_fp64)
|
|||||||
# pointer types
|
# pointer types
|
||||||
pi32_t = pointer_dtype(int32)
|
pi32_t = pointer_dtype(int32)
|
||||||
|
|
||||||
|
# -----------------------
|
||||||
|
# constexpr
|
||||||
|
# -----------------------
|
||||||
|
|
||||||
|
|
||||||
|
class constexpr:
|
||||||
|
"""
|
||||||
|
This class is used to store a value that is known at compile-time.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, value):
|
||||||
|
if isinstance(value, constexpr):
|
||||||
|
self.value = value.value
|
||||||
|
else:
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"constexpr[{self.value}]"
|
||||||
|
|
||||||
|
#
|
||||||
|
def __add__(self, other):
|
||||||
|
return self.value + other.value
|
||||||
|
|
||||||
|
def __radd__(self, other):
|
||||||
|
return other.value + self.value
|
||||||
|
|
||||||
|
def __sub__(self, other):
|
||||||
|
return self.value - other.value
|
||||||
|
|
||||||
|
def __rsub__(self, other):
|
||||||
|
return other.value - self.value
|
||||||
|
|
||||||
|
def __mul__(self, other):
|
||||||
|
return self.value * other.value
|
||||||
|
|
||||||
|
def __rmul__(self, other):
|
||||||
|
return other.value * self.value
|
||||||
|
|
||||||
|
def __truediv__(self, other):
|
||||||
|
return self.value / other.value
|
||||||
|
|
||||||
|
def __rtruediv__(self, other):
|
||||||
|
return other.value / self.value
|
||||||
|
|
||||||
|
def __floordiv__(self, other):
|
||||||
|
return self.value // other.value
|
||||||
|
|
||||||
|
def __rfloordiv__(self, other):
|
||||||
|
return other.value // self.value
|
||||||
|
|
||||||
|
#
|
||||||
|
|
||||||
|
def __gt__(self, other):
|
||||||
|
return self.value > other.value
|
||||||
|
|
||||||
|
def __rgt__(self, other):
|
||||||
|
return other.value > self.value
|
||||||
|
|
||||||
|
def __ge__(self, other):
|
||||||
|
return self.value >= other.value
|
||||||
|
|
||||||
|
def __rge__(self, other):
|
||||||
|
return other.value >= self.value
|
||||||
|
|
||||||
|
def __lt__(self, other):
|
||||||
|
return self.value < other.value
|
||||||
|
|
||||||
|
def __rlt__(self, other):
|
||||||
|
return other.value < self.value
|
||||||
|
|
||||||
|
def __le__(self, other):
|
||||||
|
return self.value <= other.value
|
||||||
|
|
||||||
|
def __rle__(self, other):
|
||||||
|
return other.value <= self.value
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return self.value == other.value
|
||||||
|
|
||||||
|
def __ne__(self, other):
|
||||||
|
return self.value != other.value
|
||||||
|
|
||||||
|
def __bool__(self):
|
||||||
|
return bool(self.value)
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwds):
|
||||||
|
return self.value(*args, **kwds)
|
||||||
|
|
||||||
|
|
||||||
class block:
|
class block:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -296,7 +384,7 @@ class block:
|
|||||||
dst_shape = []
|
dst_shape = []
|
||||||
curr = 0
|
curr = 0
|
||||||
for sl in slices:
|
for sl in slices:
|
||||||
if sl is None:
|
if isinstance(sl, constexpr) and sl.value is None:
|
||||||
dst_shape.append(1)
|
dst_shape.append(1)
|
||||||
elif sl == slice(None, None, None):
|
elif sl == slice(None, None, None):
|
||||||
dst_shape.append(src_shape[curr].value)
|
dst_shape.append(src_shape[curr].value)
|
||||||
@@ -312,93 +400,6 @@ class block:
|
|||||||
return frontend.cast(self, dtype, _builder)
|
return frontend.cast(self, dtype, _builder)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------
|
|
||||||
# constexpr
|
|
||||||
# -----------------------
|
|
||||||
|
|
||||||
class constexpr:
|
|
||||||
"""
|
|
||||||
This class is used to store a value that is known at compile-time.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, value):
|
|
||||||
if isinstance(value, constexpr):
|
|
||||||
self.value = value.value
|
|
||||||
else:
|
|
||||||
self.value = value
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return f"constexpr[{self.value}]"
|
|
||||||
|
|
||||||
#
|
|
||||||
def __add__(self, other):
|
|
||||||
return self.value + other.value
|
|
||||||
|
|
||||||
def __radd__(self, other):
|
|
||||||
return other.value + self.value
|
|
||||||
|
|
||||||
def __sub__(self, other):
|
|
||||||
return self.value - other.value
|
|
||||||
|
|
||||||
def __rsub__(self, other):
|
|
||||||
return other.value - self.value
|
|
||||||
|
|
||||||
def __mul__(self, other):
|
|
||||||
return self.value * other.value
|
|
||||||
|
|
||||||
def __rmul__(self, other):
|
|
||||||
return other.value * self.value
|
|
||||||
|
|
||||||
def __truediv__(self, other):
|
|
||||||
return self.value / other.value
|
|
||||||
|
|
||||||
def __rtruediv__(self, other):
|
|
||||||
return other.value / self.value
|
|
||||||
|
|
||||||
def __floordiv__(self, other):
|
|
||||||
return self.value // other.value
|
|
||||||
|
|
||||||
def __rfloordiv__(self, other):
|
|
||||||
return other.value // self.value
|
|
||||||
|
|
||||||
#
|
|
||||||
|
|
||||||
def __gt__(self, other):
|
|
||||||
return self.value > other.value
|
|
||||||
|
|
||||||
def __rgt__(self, other):
|
|
||||||
return other.value > self.value
|
|
||||||
|
|
||||||
def __ge__(self, other):
|
|
||||||
return self.value >= other.value
|
|
||||||
|
|
||||||
def __rge__(self, other):
|
|
||||||
return other.value >= self.value
|
|
||||||
|
|
||||||
def __lt__(self, other):
|
|
||||||
return self.value < other.value
|
|
||||||
|
|
||||||
def __rlt__(self, other):
|
|
||||||
return other.value < self.value
|
|
||||||
|
|
||||||
def __le__(self, other):
|
|
||||||
return self.value <= other.value
|
|
||||||
|
|
||||||
def __rle__(self, other):
|
|
||||||
return other.value <= self.value
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
|
||||||
return self.value == other.value
|
|
||||||
|
|
||||||
def __ne__(self, other):
|
|
||||||
return self.value != other.value
|
|
||||||
|
|
||||||
def __bool__(self):
|
|
||||||
return bool(self.value)
|
|
||||||
|
|
||||||
def __call__(self, *args, **kwds):
|
|
||||||
return self.value(*args, **kwds)
|
|
||||||
|
|
||||||
# -----------------------
|
# -----------------------
|
||||||
# SPMD Programming Model
|
# SPMD Programming Model
|
||||||
# -----------------------
|
# -----------------------
|
||||||
|
Reference in New Issue
Block a user