[RUNTIME] Major code cleanup (#711)

This PR does the following:
- CUDA utilities (e.g., cuGetInfo) are no longer compiled as part of libtriton.so.
- Refactors driver/llvm.cc to split it between PTX code generation and the Python bindings.
- By extension, this also deprecates include/external, so Triton no longer has to carry a copy of some CUDA/HIP headers.
- `triton-translate` becomes a `triton.tools.aot` Python utility that reuses functions from the triton.compile sub-module.
This commit is contained in:
Philippe Tillet
2022-09-26 16:38:06 -07:00
committed by GitHub
parent 8bb09f83ee
commit 1e91ed30d0
28 changed files with 509 additions and 31483 deletions


@@ -184,7 +184,6 @@ target_link_libraries(triton
TritonAnalysis
TritonTransforms
TritonGPUTransforms
TritonDriver
TritonLLVMIR
TritonPTX
${dialect_libs}


@@ -26,35 +26,35 @@ target_link_libraries(triton-opt PRIVATE
mlir_check_all_link_libraries(triton-opt)
add_llvm_executable(triton-translate triton-translate.cpp PARTIAL_SOURCES_INTENDED)
llvm_update_compile_flags(triton-translate)
target_link_libraries(triton-translate PRIVATE
TritonAnalysis
TritonTransforms
TritonGPUTransforms
TritonLLVMIR
TritonDriver
${dialect_libs}
${conversion_libs}
# tests
TritonTestAnalysis
# add_llvm_executable(triton-translate triton-translate.cpp PARTIAL_SOURCES_INTENDED)
#llvm_update_compile_flags(triton-translate)
# target_link_libraries(triton-translate PRIVATE
# TritonAnalysis
# TritonTransforms
# TritonGPUTransforms
# TritonLLVMIR
# TritonDriver
# ${dialect_libs}
# ${conversion_libs}
# # tests
# TritonTestAnalysis
LLVMCore
LLVMSupport
LLVMOption
LLVMCodeGen
LLVMAsmParser
# LLVMCore
# LLVMSupport
# LLVMOption
# LLVMCodeGen
# LLVMAsmParser
# MLIR core
MLIROptLib
MLIRIR
MLIRPass
MLIRSupport
MLIRTransforms
MLIRExecutionEngine
MLIRMathToLLVM
MLIRTransformUtils
MLIRLLVMToLLVMIRTranslation
MLIRNVVMToLLVMIRTranslation
)
mlir_check_all_link_libraries(triton-translate)
# # MLIR core
# MLIROptLib
# MLIRIR
# MLIRPass
# MLIRSupport
# MLIRTransforms
# MLIRExecutionEngine
# MLIRMathToLLVM
# MLIRTransformUtils
# MLIRLLVMToLLVMIRTranslation
# MLIRNVVMToLLVMIRTranslation
# )
# mlir_check_all_link_libraries(triton-translate)


@@ -1,34 +1,17 @@
#ifndef TRITON_TARGET_PTXTRANSLATION_H
#define TRITON_TARGET_PTXTRANSLATION_H
#include "triton/driver/dispatch.h"
#include <memory>
#include <string>
namespace mlir {
class ModuleOp;
} // namespace mlir
namespace llvm {
class Module;
} // namespace llvm
namespace triton {
template <CUdevice_attribute attr> int cuGetInfo(CUdevice device) {
int res;
driver::dispatch::cuDeviceGetAttribute(&res, attr, device);
return res;
}
void getCuCCAndVersionFromDevice(uint64_t device, int *cc, int *version,
std::string *ptxasPath);
// Translate TritonGPU IR to PTX code.
std::tuple<std::string, // ptx code
size_t, // PTX cc
int, // PTX version
std::string // ptxas path
>
translateTritonGPUToPTX(mlir::ModuleOp module, uint64_t device);
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version);
} // namespace triton


@@ -1,376 +0,0 @@
#pragma once
#ifndef _TRITON_DRIVER_DISPATCH_H_
#define _TRITON_DRIVER_DISPATCH_H_
#include <dlfcn.h>
#include <type_traits>
// CUDA Backend
#include "triton/external/CUDA/cuda.h"
#include "triton/external/CUDA/nvml.h"
//// HIP backend
//#define __HIP_PLATFORM_AMD__
#include "triton/external/hip.h"
// Exceptions
#include <iostream>
#include <stdexcept>
namespace llvm {
class PassRegistry;
class Module;
} // namespace llvm
namespace triton {
namespace driver {
class cu_context;
template <class T> void check(T) {}
void check(CUresult err);
void check(hipError_t err);
class dispatch {
protected:
template <class F> struct return_type;
template <class R, class... A> struct return_type<R (*)(A...)> {
typedef R type;
};
typedef bool (*f_init_t)();
template <f_init_t initializer, typename FunPtrT, typename... Args>
static typename return_type<FunPtrT>::type
f_impl(void *&lib_h, FunPtrT, void *&cache, const char *name, Args... args) {
initializer();
if (cache == nullptr) {
cache = dlsym(lib_h, name);
if (cache == 0) {
#ifdef __EXCEPTIONS
throw std::runtime_error("dlsym unable to load function");
#else
std::cerr << "Triton: dlsym unable to load function `" << name << "`"
<< std::endl;
std::abort();
#endif
}
}
FunPtrT fptr;
*reinterpret_cast<void **>(&fptr) = cache;
typename return_type<FunPtrT>::type res = (*fptr)(args...);
check(res);
return res;
}
public:
static void release();
// Nvidia
static bool nvmlinit();
static bool cuinit();
// AMD
static bool hipinit();
/* ------------------- *
* CUDA
* ------------------- */
// context management
static CUresult cuInit(unsigned int Flags);
static CUresult cuCtxDestroy_v2(CUcontext ctx);
static CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags,
CUdevice dev);
static CUresult cuCtxPushCurrent_v2(CUcontext ctx);
static CUresult cuCtxPopCurrent_v2(CUcontext *pctx);
static CUresult cuCtxGetDevice(CUdevice *result);
static CUresult cuCtxEnablePeerAccess(CUcontext peerContext,
unsigned int flags);
static CUresult cuDriverGetVersion(int *driverVersion);
// device management
static CUresult cuDeviceGet(CUdevice *device, int ordinal);
static CUresult cuDeviceGetName(char *name, int len, CUdevice dev);
static CUresult cuDeviceGetPCIBusId(char *id, int len, CUdevice dev);
static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib,
CUdevice dev);
static CUresult cuDeviceGetCount(int *count);
// link management
static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type,
void *data, size_t size, const char *name,
unsigned int numOptions,
CUjit_option *options, void **optionValues);
static CUresult cuLinkCreate_v2(unsigned int numOptions,
CUjit_option *options, void **optionValues,
CUlinkState *stateOut);
static CUresult cuLinkComplete(CUlinkState state, void **cubinOut,
size_t *sizeOut);
static CUresult cuLinkDestroy(CUlinkState state);
// module management
static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t *bytes,
CUmodule hmod, const char *name);
static CUresult cuModuleLoad(CUmodule *module, const char *fname);
static CUresult cuModuleLoadData(CUmodule *module, const void *image);
static CUresult cuModuleUnload(CUmodule hmod);
static CUresult cuModuleLoadDataEx(CUmodule *module, const void *image,
unsigned int numOptions,
CUjit_option *options,
void **optionValues);
static CUresult cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod,
const char *name);
// stream management
static CUresult cuStreamCreate(CUstream *phStream, unsigned int Flags);
static CUresult cuStreamSynchronize(CUstream hStream);
static CUresult cuStreamGetCtx(CUstream hStream, CUcontext *pctx);
static CUresult cuStreamDestroy_v2(CUstream hStream);
static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX,
unsigned int gridDimY, unsigned int gridDimZ,
unsigned int blockDimX, unsigned int blockDimY,
unsigned int blockDimZ,
unsigned int sharedMemBytes, CUstream hStream,
void **kernelParams, void **extra);
// function management
static CUresult cuFuncGetAttribute(int *pi, CUfunction_attribute attrib,
CUfunction hfunc);
static CUresult cuFuncSetAttribute(CUfunction hfunc,
CUfunction_attribute attrib, int value);
static CUresult cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config);
// memory management
static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize);
static CUresult cuPointerGetAttribute(void *data,
CUpointer_attribute attribute,
CUdeviceptr ptr);
static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N,
CUstream stream);
static CUresult cuMemcpyDtoH_v2(void *dstHost, CUdeviceptr srcDevice,
size_t ByteCount);
static CUresult cuMemFree_v2(CUdeviceptr dptr);
static CUresult cuMemcpyDtoHAsync_v2(void *dstHost, CUdeviceptr srcDevice,
size_t ByteCount, CUstream hStream);
static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice,
const void *srcHost, size_t ByteCount,
CUstream hStream);
static CUresult cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost,
size_t ByteCount);
// event management
static CUresult cuEventCreate(CUevent *phEvent, unsigned int Flags);
static CUresult cuEventElapsedTime(float *pMilliseconds, CUevent hStart,
CUevent hEnd);
static CUresult cuEventRecord(CUevent hEvent, CUstream hStream);
static CUresult cuEventDestroy_v2(CUevent hEvent);
/* ------------------- *
* NVML
* ------------------- */
static nvmlReturn_t nvmlDeviceGetHandleByPciBusId_v2(const char *pciBusId,
nvmlDevice_t *device);
static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device,
nvmlClockType_t type,
unsigned int *clock);
static nvmlReturn_t nvmlDeviceGetMaxClockInfo(nvmlDevice_t device,
nvmlClockType_t type,
unsigned int *clock);
static nvmlReturn_t nvmlDeviceSetApplicationsClocks(nvmlDevice_t device,
unsigned int mem_clock,
unsigned int sm_clock);
/* ------------------- *
* HIP
* ------------------- */
// context management
static hipError_t hipInit(unsigned int Flags);
static hipError_t hipCtxDestroy(hipCtx_t ctx);
static hipError_t hipCtxCreate(hipCtx_t *pctx, unsigned int flags,
hipDevice_t dev);
static hipError_t hipCtxPushCurrent(hipCtx_t ctx);
static hipError_t hipCtxPopCurrent(hipCtx_t *pctx);
static hipError_t hipCtxGetDevice(hipDevice_t *result);
static hipError_t hipCtxEnablePeerAccess(hipCtx_t peerContext,
unsigned int flags);
static hipError_t hipDriverGetVersion(int *driverVersion);
// device management
static hipError_t hipGetDevice(hipDevice_t *device, int ordinal);
static hipError_t hipDeviceGetName(char *name, int len, hipDevice_t dev);
static hipError_t hipDeviceGetPCIBusId(char *id, int len, hipDevice_t dev);
static hipError_t hipDeviceGetAttribute(int *pi, hipDeviceAttribute_t attrib,
hipDevice_t dev);
static hipError_t hipGetDeviceCount(int *count);
// module management
static hipError_t hipModuleGetGlobal(hipDeviceptr_t *dptr, size_t *bytes,
hipModule_t hmod, const char *name);
static hipError_t hipModuleLoad(hipModule_t *module, const char *fname);
static hipError_t hipModuleLoadData(hipModule_t *module, const void *image);
static hipError_t hipModuleUnload(hipModule_t hmod);
static hipError_t hipModuleLoadDataEx(hipModule_t *module, const void *image,
unsigned int numOptions,
hipJitOption *options,
void **optionValues);
static hipError_t hipModuleGetFunction(hipFunction_t *hfunc, hipModule_t hmod,
const char *name);
// stream management
static hipError_t hipStreamCreate(hipStream_t *phStream, unsigned int Flags);
static hipError_t hipStreamSynchronize(hipStream_t hStream);
static hipError_t hipStreamDestroy(hipStream_t hStream);
static hipError_t
hipModuleLaunchKernel(hipFunction_t f, unsigned int gridDimX,
unsigned int gridDimY, unsigned int gridDimZ,
unsigned int blockDimX, unsigned int blockDimY,
unsigned int blockDimZ, unsigned int sharedMemBytes,
hipStream_t hStream, void **kernelParams, void **extra);
// function management
static hipError_t hipFuncGetAttributes(hipFuncAttributes *attrib,
void *hfunc);
static hipError_t hipFuncSetAttribute(hipFunction_t hfunc,
hipFuncAttribute attrib, int value);
static hipError_t hipFuncSetCacheConfig(hipFunction_t hfunc,
hipFuncCache_t config);
// memory management
static hipError_t hipMalloc(hipDeviceptr_t *dptr, size_t bytesize);
static hipError_t hipPointerGetAttribute(void *data,
CUpointer_attribute attribute,
hipDeviceptr_t ptr);
static hipError_t hipMemsetD8Async(hipDeviceptr_t dst, unsigned char x,
size_t N, hipStream_t stream);
static hipError_t hipMemcpyDtoH(void *dstHost, hipDeviceptr_t srcDevice,
size_t ByteCount);
static hipError_t hipFree(hipDeviceptr_t dptr);
static hipError_t hipMemcpyDtoHAsync(void *dstHost, hipDeviceptr_t srcDevice,
size_t ByteCount, hipStream_t hStream);
static hipError_t hipMemcpyHtoDAsync(hipDeviceptr_t dstDevice,
const void *srcHost, size_t ByteCount,
hipStream_t hStream);
static hipError_t hipMemcpyHtoD(hipDeviceptr_t dstDevice, const void *srcHost,
size_t ByteCount);
// event management
static hipError_t hipEventCreate(hipEvent_t *phEvent, unsigned int Flags);
static hipError_t hipEventElapsedTime(float *pMilliseconds, hipEvent_t hStart,
hipEvent_t hEnd);
static hipError_t hipEventRecord(hipEvent_t hEvent, hipStream_t hStream);
static hipError_t hipEventDestroy(hipEvent_t hEvent);
private:
// Libraries
static void *cuda_;
static void *nvml_;
static void *hip_;
/* ------------------- *
* CUDA
* ------------------- */
// context management
static void *cuCtxGetCurrent_;
static void *cuCtxSetCurrent_;
static void *cuCtxDestroy_v2_;
static void *cuCtxCreate_v2_;
static void *cuCtxGetDevice_;
static void *cuCtxPushCurrent_v2_;
static void *cuCtxPopCurrent_v2_;
static void *cuCtxEnablePeerAccess_;
static void *cuDriverGetVersion_;
static void *cuInit_;
// device management
static void *cuDeviceGet_;
static void *cuDeviceGetName_;
static void *cuDeviceGetPCIBusId_;
static void *cuDeviceGetAttribute_;
static void *cuDeviceGetCount_;
// link management
static void *cuLinkAddData_v2_;
static void *cuLinkCreate_v2_;
static void *cuLinkDestroy_;
static void *cuLinkComplete_;
// module management
static void *cuModuleGetGlobal_v2_;
static void *cuModuleLoad_;
static void *cuModuleUnload_;
static void *cuModuleLoadDataEx_;
static void *cuModuleLoadData_;
static void *cuModuleGetFunction_;
// stream management
static void *cuStreamCreate_;
static void *cuStreamSynchronize_;
static void *cuStreamDestroy_v2_;
static void *cuStreamGetCtx_;
static void *cuLaunchKernel_;
// function management
static void *cuFuncGetAttribute_;
static void *cuFuncSetAttribute_;
static void *cuFuncSetCacheConfig_;
// memory management
static void *cuMemcpyDtoH_v2_;
static void *cuMemFree_v2_;
static void *cuMemcpyDtoHAsync_v2_;
static void *cuMemcpyHtoDAsync_v2_;
static void *cuMemcpyHtoD_v2_;
static void *cuMemAlloc_v2_;
static void *cuMemsetD8Async_;
static void *cuPointerGetAttribute_;
// event management
static void *cuEventCreate_;
static void *cuEventElapsedTime_;
static void *cuEventRecord_;
static void *cuEventDestroy_v2_;
/* ------------------- *
* NVML
* ------------------- */
static void *nvmlInit_v2_;
static void *nvmlDeviceGetHandleByPciBusId_v2_;
static void *nvmlDeviceGetClockInfo_;
static void *nvmlDeviceGetMaxClockInfo_;
static void *nvmlDeviceSetApplicationsClocks_;
/* ------------------- *
* HIP
* ------------------- */
// context management
static void *hipInit_;
static void *hipCtxDestroy_;
static void *hipCtxCreate_;
static void *hipCtxPushCurrent_;
static void *hipCtxPopCurrent_;
static void *hipCtxGetDevice_;
static void *hipCtxEnablePeerAccess_;
static void *hipDriverGetVersion_;
// device management
static void *hipGetDevice_;
static void *hipDeviceGetName_;
static void *hipDeviceGetPCIBusId_;
static void *hipDeviceGetAttribute_;
static void *hipGetDeviceCount_;
// module management
static void *hipModuleGetGlobal_;
static void *hipModuleLoad_;
static void *hipModuleLoadData_;
static void *hipModuleUnload_;
static void *hipModuleLoadDataEx_;
static void *hipModuleGetFunction_;
// stream management
static void *hipStreamCreate_;
static void *hipStreamSynchronize_;
static void *hipStreamDestroy_;
static void *hipModuleLaunchKernel_;
// function management
static void *hipFuncGetAttributes_;
static void *hipFuncSetAttribute_;
static void *hipFuncSetCacheConfig_;
// memory management
static void *hipMalloc_;
static void *hipPointerGetAttribute_;
static void *hipMemsetD8Async_;
static void *hipMemcpyDtoH_;
static void *hipFree_;
static void *hipMemcpyDtoHAsync_;
static void *hipMemcpyHtoDAsync_;
static void *hipMemcpyHtoD_;
// event management
static void *hipEventCreate_;
static void *hipEventElapsedTime_;
static void *hipEventRecord_;
static void *hipEventDestroy_;
};
} // namespace driver
} // namespace triton
#endif
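
For reference, an editorial sketch (not part of this commit): the class above resolves each driver symbol with dlsym on the first call and caches the raw pointer for all subsequent calls. The same pattern in isolation, with libm's cos standing in for a CUDA driver entry point:

// lazy_dispatch_sketch.cc -- illustrative only; "libm.so.6" and cos() are
// stand-ins for the CUDA driver library and its entry points.
#include <dlfcn.h>
#include <cstdio>
#include <stdexcept>

static void *lib_handle = nullptr; // plays the role of dispatch::cuda_
static void *cos_cache = nullptr;  // plays the role of the cuXxx_ members

static double dispatch_cos(double x) {
  if (!lib_handle)
    lib_handle = dlopen("libm.so.6", RTLD_LAZY); // cuinit()'s job
  if (!lib_handle)
    throw std::runtime_error("dlopen failed");
  if (!cos_cache) {
    cos_cache = dlsym(lib_handle, "cos"); // resolve once, cache forever
    if (!cos_cache)
      throw std::runtime_error("dlsym unable to load function");
  }
  double (*fptr)(double);
  *reinterpret_cast<void **>(&fptr) = cos_cache; // same cast trick as f_impl
  return (*fptr)(x);
}

int main() {
  std::printf("cos(0) = %f\n", dispatch_cos(0.0)); // prints 1.000000
  return 0;
}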


@@ -1,254 +0,0 @@
#pragma once
#ifndef _TRITON_DRIVER_ERROR_H_
#define _TRITON_DRIVER_ERROR_H_
#include "triton/driver/dispatch.h"
#include <exception>
namespace triton {
namespace driver {
namespace exception {
namespace nvrtc {
#define TRITON_CREATE_NVRTC_EXCEPTION(name, msg) \
class name : public std::exception { \
public: \
const char *what() const throw() override { return "NVRTC: Error- " msg; } \
}
TRITON_CREATE_NVRTC_EXCEPTION(out_of_memory, "out of memory");
TRITON_CREATE_NVRTC_EXCEPTION(program_creation_failure,
"program creation failure");
TRITON_CREATE_NVRTC_EXCEPTION(invalid_input, "invalid input");
TRITON_CREATE_NVRTC_EXCEPTION(invalid_program, "invalid program");
TRITON_CREATE_NVRTC_EXCEPTION(invalid_option, "invalid option");
TRITON_CREATE_NVRTC_EXCEPTION(compilation, "compilation");
TRITON_CREATE_NVRTC_EXCEPTION(builtin_operation_failure,
"builtin operation failure");
TRITON_CREATE_NVRTC_EXCEPTION(unknown_error, "unknown error");
#undef TRITON_CREATE_NVRTC_EXCEPTION
} // namespace nvrtc
namespace cuda {
class base : public std::exception {};
#define TRITON_CREATE_CUDA_EXCEPTION(name, msg) \
class name : public base { \
public: \
const char *what() const throw() override { return "CUDA: Error- " msg; } \
}
TRITON_CREATE_CUDA_EXCEPTION(invalid_value, "invalid value");
TRITON_CREATE_CUDA_EXCEPTION(out_of_memory, "out of memory");
TRITON_CREATE_CUDA_EXCEPTION(not_initialized, "not initialized");
TRITON_CREATE_CUDA_EXCEPTION(deinitialized, "deinitialized");
TRITON_CREATE_CUDA_EXCEPTION(profiler_disabled, "profiler disabled");
TRITON_CREATE_CUDA_EXCEPTION(profiler_not_initialized,
"profiler not initialized");
TRITON_CREATE_CUDA_EXCEPTION(profiler_already_started,
"profiler already started");
TRITON_CREATE_CUDA_EXCEPTION(profiler_already_stopped,
"profiler already stopped");
TRITON_CREATE_CUDA_EXCEPTION(no_device, "no device");
TRITON_CREATE_CUDA_EXCEPTION(invalid_device, "invalid device");
TRITON_CREATE_CUDA_EXCEPTION(invalid_image, "invalid image");
TRITON_CREATE_CUDA_EXCEPTION(invalid_context, "invalid context");
TRITON_CREATE_CUDA_EXCEPTION(context_already_current,
"context already current");
TRITON_CREATE_CUDA_EXCEPTION(map_failed, "map failed");
TRITON_CREATE_CUDA_EXCEPTION(unmap_failed, "unmap failed");
TRITON_CREATE_CUDA_EXCEPTION(array_is_mapped, "array is mapped");
TRITON_CREATE_CUDA_EXCEPTION(already_mapped, "already mapped");
TRITON_CREATE_CUDA_EXCEPTION(no_binary_for_gpu, "no binary for gpu");
TRITON_CREATE_CUDA_EXCEPTION(already_acquired, "already acquired");
TRITON_CREATE_CUDA_EXCEPTION(not_mapped, "not mapped");
TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_array, "not mapped as array");
TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_pointer, "not mapped as pointer");
TRITON_CREATE_CUDA_EXCEPTION(ecc_uncorrectable, "ecc uncorrectable");
TRITON_CREATE_CUDA_EXCEPTION(unsupported_limit, "unsupported limit");
TRITON_CREATE_CUDA_EXCEPTION(context_already_in_use, "context already in use");
TRITON_CREATE_CUDA_EXCEPTION(peer_access_unsupported,
"peer access unsupported");
TRITON_CREATE_CUDA_EXCEPTION(invalid_ptx, "invalid ptx");
TRITON_CREATE_CUDA_EXCEPTION(invalid_graphics_context,
"invalid graphics context");
TRITON_CREATE_CUDA_EXCEPTION(invalid_source, "invalid source");
TRITON_CREATE_CUDA_EXCEPTION(file_not_found, "file not found");
TRITON_CREATE_CUDA_EXCEPTION(shared_object_symbol_not_found,
"shared object symbol not found");
TRITON_CREATE_CUDA_EXCEPTION(shared_object_init_failed,
"shared object init failed");
TRITON_CREATE_CUDA_EXCEPTION(operating_system, "operating system");
TRITON_CREATE_CUDA_EXCEPTION(invalid_handle, "invalid handle");
TRITON_CREATE_CUDA_EXCEPTION(not_found, "not found");
TRITON_CREATE_CUDA_EXCEPTION(not_ready, "not ready");
TRITON_CREATE_CUDA_EXCEPTION(illegal_address, "illegal address");
TRITON_CREATE_CUDA_EXCEPTION(launch_out_of_resources,
"launch out of resources");
TRITON_CREATE_CUDA_EXCEPTION(launch_timeout, "launch timeout");
TRITON_CREATE_CUDA_EXCEPTION(launch_incompatible_texturing,
"launch incompatible texturing");
TRITON_CREATE_CUDA_EXCEPTION(peer_access_already_enabled,
"peer access already enabled");
TRITON_CREATE_CUDA_EXCEPTION(peer_access_not_enabled,
"peer access not enabled");
TRITON_CREATE_CUDA_EXCEPTION(primary_context_active, "primary context active");
TRITON_CREATE_CUDA_EXCEPTION(context_is_destroyed, "context is destroyed");
TRITON_CREATE_CUDA_EXCEPTION(assert_error, "assert");
TRITON_CREATE_CUDA_EXCEPTION(too_many_peers, "too many peers");
TRITON_CREATE_CUDA_EXCEPTION(host_memory_already_registered,
"host memory already registered");
TRITON_CREATE_CUDA_EXCEPTION(host_memory_not_registered,
"hot memory not registered");
TRITON_CREATE_CUDA_EXCEPTION(hardware_stack_error, "hardware stack error");
TRITON_CREATE_CUDA_EXCEPTION(illegal_instruction, "illegal instruction");
TRITON_CREATE_CUDA_EXCEPTION(misaligned_address, "misaligned address");
TRITON_CREATE_CUDA_EXCEPTION(invalid_address_space, "invalid address space");
TRITON_CREATE_CUDA_EXCEPTION(invalid_pc, "invalid pc");
TRITON_CREATE_CUDA_EXCEPTION(launch_failed, "launch failed");
TRITON_CREATE_CUDA_EXCEPTION(not_permitted, "not permitted");
TRITON_CREATE_CUDA_EXCEPTION(not_supported, "not supported");
TRITON_CREATE_CUDA_EXCEPTION(unknown, "unknown");
#undef TRITON_CREATE_CUDA_EXCEPTION
} // namespace cuda
namespace cublas {
class base : public std::exception {};
#define TRITON_CREATE_CUBLAS_EXCEPTION(name, msg) \
class name : public base { \
public: \
const char *what() const throw() override { \
return "CUBLAS: Error- " msg; \
} \
}
TRITON_CREATE_CUBLAS_EXCEPTION(not_initialized, "not initialized");
TRITON_CREATE_CUBLAS_EXCEPTION(alloc_failed, "alloc failed");
TRITON_CREATE_CUBLAS_EXCEPTION(invalid_value, "invalid value");
TRITON_CREATE_CUBLAS_EXCEPTION(arch_mismatch, "arch mismatch");
TRITON_CREATE_CUBLAS_EXCEPTION(mapping_error, "mapping error");
TRITON_CREATE_CUBLAS_EXCEPTION(execution_failed, "execution failed");
TRITON_CREATE_CUBLAS_EXCEPTION(internal_error, "internal error");
TRITON_CREATE_CUBLAS_EXCEPTION(not_supported, "not supported");
TRITON_CREATE_CUBLAS_EXCEPTION(license_error, "license error");
TRITON_CREATE_CUBLAS_EXCEPTION(unknown, "unknown");
#undef TRITON_CREATE_CUBLAS_EXCEPTION
} // namespace cublas
namespace cudnn {
#define TRITON_CREATE_CUDNN_EXCEPTION(name, msg) \
class name : public std::exception { \
public: \
const char *what() const throw() override { return "CUDNN: Error- " msg; } \
}
TRITON_CREATE_CUDNN_EXCEPTION(not_initialized, "not initialized");
TRITON_CREATE_CUDNN_EXCEPTION(alloc_failed, "allocation failed");
TRITON_CREATE_CUDNN_EXCEPTION(bad_param, "bad param");
TRITON_CREATE_CUDNN_EXCEPTION(internal_error, "internal error");
TRITON_CREATE_CUDNN_EXCEPTION(invalid_value, "invalid value");
TRITON_CREATE_CUDNN_EXCEPTION(arch_mismatch, "arch mismatch");
TRITON_CREATE_CUDNN_EXCEPTION(mapping_error, "mapping error");
TRITON_CREATE_CUDNN_EXCEPTION(execution_failed, "execution failed");
TRITON_CREATE_CUDNN_EXCEPTION(not_supported, "not supported");
TRITON_CREATE_CUDNN_EXCEPTION(license_error, "license error");
TRITON_CREATE_CUDNN_EXCEPTION(runtime_prerequisite_missing,
"prerequisite missing");
TRITON_CREATE_CUDNN_EXCEPTION(runtime_in_progress, "runtime in progress");
TRITON_CREATE_CUDNN_EXCEPTION(runtime_fp_overflow, "runtime fp overflow");
#undef TRITON_CREATE_CUDNN_EXCEPTION
} // namespace cudnn
namespace hip {
class base : public std::exception {};
#define TRITON_CREATE_HIP_EXCEPTION(name, msg) \
class name : public base { \
public: \
const char *what() const throw() override { return "HIP: Error- " msg; } \
}
TRITON_CREATE_HIP_EXCEPTION(invalid_value, "invalid value");
TRITON_CREATE_HIP_EXCEPTION(out_of_memory, "out of memory");
TRITON_CREATE_HIP_EXCEPTION(not_initialized, "not initialized");
TRITON_CREATE_HIP_EXCEPTION(deinitialized, "deinitialized");
TRITON_CREATE_HIP_EXCEPTION(profiler_disabled, "profiler disabled");
TRITON_CREATE_HIP_EXCEPTION(profiler_not_initialized,
"profiler not initialized");
TRITON_CREATE_HIP_EXCEPTION(profiler_already_started,
"profiler already started");
TRITON_CREATE_HIP_EXCEPTION(profiler_already_stopped,
"profiler already stopped");
TRITON_CREATE_HIP_EXCEPTION(no_device, "no device");
TRITON_CREATE_HIP_EXCEPTION(invalid_device, "invalid device");
TRITON_CREATE_HIP_EXCEPTION(invalid_image, "invalid image");
TRITON_CREATE_HIP_EXCEPTION(invalid_context, "invalid context");
TRITON_CREATE_HIP_EXCEPTION(context_already_current, "context already current");
TRITON_CREATE_HIP_EXCEPTION(map_failed, "map failed");
TRITON_CREATE_HIP_EXCEPTION(unmap_failed, "unmap failed");
TRITON_CREATE_HIP_EXCEPTION(array_is_mapped, "array is mapped");
TRITON_CREATE_HIP_EXCEPTION(already_mapped, "already mapped");
TRITON_CREATE_HIP_EXCEPTION(no_binary_for_gpu, "no binary for gpu");
TRITON_CREATE_HIP_EXCEPTION(already_acquired, "already acquired");
TRITON_CREATE_HIP_EXCEPTION(not_mapped, "not mapped");
TRITON_CREATE_HIP_EXCEPTION(not_mapped_as_array, "not mapped as array");
TRITON_CREATE_HIP_EXCEPTION(not_mapped_as_pointer, "not mapped as pointer");
TRITON_CREATE_HIP_EXCEPTION(ecc_uncorrectable, "ecc uncorrectable");
TRITON_CREATE_HIP_EXCEPTION(unsupported_limit, "unsupported limit");
TRITON_CREATE_HIP_EXCEPTION(context_already_in_use, "context already in use");
TRITON_CREATE_HIP_EXCEPTION(peer_access_unsupported, "peer access unsupported");
TRITON_CREATE_HIP_EXCEPTION(invalid_ptx, "invalid ptx");
TRITON_CREATE_HIP_EXCEPTION(invalid_graphics_context,
"invalid graphics context");
TRITON_CREATE_HIP_EXCEPTION(invalid_source, "invalid source");
TRITON_CREATE_HIP_EXCEPTION(file_not_found, "file not found");
TRITON_CREATE_HIP_EXCEPTION(shared_object_symbol_not_found,
"shared object symbol not found");
TRITON_CREATE_HIP_EXCEPTION(shared_object_init_failed,
"shared object init failed");
TRITON_CREATE_HIP_EXCEPTION(operating_system, "operating system");
TRITON_CREATE_HIP_EXCEPTION(invalid_handle, "invalid handle");
TRITON_CREATE_HIP_EXCEPTION(not_found, "not found");
TRITON_CREATE_HIP_EXCEPTION(not_ready, "not ready");
TRITON_CREATE_HIP_EXCEPTION(illegal_address, "illegal address");
TRITON_CREATE_HIP_EXCEPTION(launch_out_of_resources, "launch out of resources");
TRITON_CREATE_HIP_EXCEPTION(launch_timeout, "launch timeout");
TRITON_CREATE_HIP_EXCEPTION(launch_incompatible_texturing,
"launch incompatible texturing");
TRITON_CREATE_HIP_EXCEPTION(peer_access_already_enabled,
"peer access already enabled");
TRITON_CREATE_HIP_EXCEPTION(peer_access_not_enabled, "peer access not enabled");
TRITON_CREATE_HIP_EXCEPTION(primary_context_active, "primary context active");
TRITON_CREATE_HIP_EXCEPTION(context_is_destroyed, "context is destroyed");
TRITON_CREATE_HIP_EXCEPTION(assert_error, "assert");
TRITON_CREATE_HIP_EXCEPTION(too_many_peers, "too many peers");
TRITON_CREATE_HIP_EXCEPTION(host_memory_already_registered,
"host memory already registered");
TRITON_CREATE_HIP_EXCEPTION(host_memory_not_registered,
"hot memory not registered");
TRITON_CREATE_HIP_EXCEPTION(hardware_stack_error, "hardware stack error");
TRITON_CREATE_HIP_EXCEPTION(illegal_instruction, "illegal instruction");
TRITON_CREATE_HIP_EXCEPTION(misaligned_address, "misaligned address");
TRITON_CREATE_HIP_EXCEPTION(invalid_address_space, "invalid address space");
TRITON_CREATE_HIP_EXCEPTION(invalid_pc, "invalid pc");
TRITON_CREATE_HIP_EXCEPTION(launch_failed, "launch failed");
TRITON_CREATE_HIP_EXCEPTION(not_permitted, "not permitted");
TRITON_CREATE_HIP_EXCEPTION(not_supported, "not supported");
TRITON_CREATE_HIP_EXCEPTION(invalid_symbol, "invalid symbol");
TRITON_CREATE_HIP_EXCEPTION(unknown, "unknown");
#undef TRITON_CREATE_HIP_EXCEPTION
} // namespace hip
} // namespace exception
} // namespace driver
} // namespace triton
#endif
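
An editorial note: check(CUresult), declared in dispatch.h, presumably maps driver error codes onto the exception types above; its definition lived in the deleted error.cc. A self-contained sketch of that shape, with plain integers standing in for CUresult values and two classes mirroring what TRITON_CREATE_CUDA_EXCEPTION expands to:

// check_sketch.cc -- illustrative only; codes 0/1/2 mirror CUDA_SUCCESS,
// CUDA_ERROR_INVALID_VALUE and CUDA_ERROR_OUT_OF_MEMORY.
#include <cstdio>
#include <exception>

class base : public std::exception {};
class invalid_value : public base {
public:
  const char *what() const throw() override {
    return "CUDA: Error- invalid value";
  }
};
class out_of_memory : public base {
public:
  const char *what() const throw() override {
    return "CUDA: Error- out of memory";
  }
};

void check(int err) { // stand-in for check(CUresult err)
  switch (err) {
  case 0:
    return; // CUDA_SUCCESS
  case 1:
    throw invalid_value(); // CUDA_ERROR_INVALID_VALUE
  case 2:
    throw out_of_memory(); // CUDA_ERROR_OUT_OF_MEMORY
  default:
    throw base();
  }
}

int main() {
  try {
    check(2);
  } catch (const std::exception &e) {
    std::printf("%s\n", e.what()); // CUDA: Error- out of memory
  }
  return 0;
}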


@@ -1,22 +0,0 @@
#include "triton/external/CUDA/cuda.h"
#include "triton/external/hip.h"
#include <string>
namespace llvm {
class Module;
}
namespace triton {
namespace driver {
void init_llvm();
std::string path_to_ptxas(int &version);
std::string llir_to_ptx(llvm::Module *module, int cc, int version);
std::string ptx_to_cubin(const std::string &ptx, const std::string &ptxas_path,
int cc);
CUmodule ptx_to_cumodule(const std::string &ptx, int cc);
std::string llir_to_amdgpu(llvm::Module *module, const std::string &proc);
hipModule_t amdgpu_to_hipmodule(const std::string &path);
} // namespace driver
} // namespace triton

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -1,293 +0,0 @@
#ifndef __external_hip_h__
#define __external_hip_h__
/*
* @brief hipError_t
* @enum
* @ingroup Enumerations
*/
// Developer note - when updating these, update the hipErrorName and hipErrorString functions in
// NVCC and HCC paths. Also update the hipCUDAErrorTohipError function in NVCC path.
// Ignoring error-code return values from hip APIs is discouraged. On C++17,
// we can make that yield a warning
#include <cstddef>
typedef enum hipError_t {
hipSuccess = 0, ///< Successful completion.
hipErrorInvalidValue = 1, ///< One or more of the parameters passed to the API call is NULL
///< or not in an acceptable range.
hipErrorOutOfMemory = 2,
// Deprecated
hipErrorMemoryAllocation = 2, ///< Memory allocation error.
hipErrorNotInitialized = 3,
// Deprecated
hipErrorInitializationError = 3,
hipErrorDeinitialized = 4,
hipErrorProfilerDisabled = 5,
hipErrorProfilerNotInitialized = 6,
hipErrorProfilerAlreadyStarted = 7,
hipErrorProfilerAlreadyStopped = 8,
hipErrorInvalidConfiguration = 9,
hipErrorInvalidPitchValue = 12,
hipErrorInvalidSymbol = 13,
hipErrorInvalidDevicePointer = 17, ///< Invalid Device Pointer
hipErrorInvalidMemcpyDirection = 21, ///< Invalid memory copy direction
hipErrorInsufficientDriver = 35,
hipErrorMissingConfiguration = 52,
hipErrorPriorLaunchFailure = 53,
hipErrorInvalidDeviceFunction = 98,
hipErrorNoDevice = 100, ///< Call to hipGetDeviceCount returned 0 devices
hipErrorInvalidDevice = 101, ///< DeviceID must be in range 0...#compute-devices.
hipErrorInvalidImage = 200,
hipErrorInvalidContext = 201, ///< Produced when input context is invalid.
hipErrorContextAlreadyCurrent = 202,
hipErrorMapFailed = 205,
// Deprecated
hipErrorMapBufferObjectFailed = 205, ///< Produced when the IPC memory attach failed from ROCr.
hipErrorUnmapFailed = 206,
hipErrorArrayIsMapped = 207,
hipErrorAlreadyMapped = 208,
hipErrorNoBinaryForGpu = 209,
hipErrorAlreadyAcquired = 210,
hipErrorNotMapped = 211,
hipErrorNotMappedAsArray = 212,
hipErrorNotMappedAsPointer = 213,
hipErrorECCNotCorrectable = 214,
hipErrorUnsupportedLimit = 215,
hipErrorContextAlreadyInUse = 216,
hipErrorPeerAccessUnsupported = 217,
hipErrorInvalidKernelFile = 218, ///< In CUDA DRV, it is CUDA_ERROR_INVALID_PTX
hipErrorInvalidGraphicsContext = 219,
hipErrorInvalidSource = 300,
hipErrorFileNotFound = 301,
hipErrorSharedObjectSymbolNotFound = 302,
hipErrorSharedObjectInitFailed = 303,
hipErrorOperatingSystem = 304,
hipErrorInvalidHandle = 400,
// Deprecated
hipErrorInvalidResourceHandle = 400, ///< Resource handle (hipEvent_t or hipStream_t) invalid.
hipErrorNotFound = 500,
hipErrorNotReady = 600, ///< Indicates that asynchronous operations enqueued earlier are not
///< ready. This is not actually an error, but is used to distinguish
///< from hipSuccess (which indicates completion). APIs that return
///< this error include hipEventQuery and hipStreamQuery.
hipErrorIllegalAddress = 700,
hipErrorLaunchOutOfResources = 701, ///< Out of resources error.
hipErrorLaunchTimeOut = 702,
hipErrorPeerAccessAlreadyEnabled =
704, ///< Peer access was already enabled from the current device.
hipErrorPeerAccessNotEnabled =
705, ///< Peer access was never enabled from the current device.
hipErrorSetOnActiveProcess = 708,
hipErrorAssert = 710, ///< Produced when the kernel calls assert.
hipErrorHostMemoryAlreadyRegistered =
712, ///< Produced when trying to lock a page-locked memory.
hipErrorHostMemoryNotRegistered =
713, ///< Produced when trying to unlock a non-page-locked memory.
hipErrorLaunchFailure =
719, ///< An exception occurred on the device while executing a kernel.
hipErrorCooperativeLaunchTooLarge =
720, ///< This error indicates that the number of blocks launched per grid for a kernel
///< that was launched via cooperative launch APIs exceeds the maximum number of
///< allowed blocks for the current device
hipErrorNotSupported = 801, ///< Produced when the hip API is not supported/implemented
hipErrorUnknown = 999, ///< Unknown error.
// HSA Runtime Error Codes start here.
hipErrorRuntimeMemory = 1052, ///< HSA runtime memory call returned error. Typically not seen
///< in production systems.
hipErrorRuntimeOther = 1053, ///< HSA runtime call other than memory returned error. Typically
///< not seen in production systems.
hipErrorTbd ///< Marker that more error codes are needed.
} hipError_t;
typedef struct ihipCtx_t* hipCtx_t;
// Note many APIs also use integer deviceIds as an alternative to the device pointer:
typedef int hipDevice_t;
typedef enum hipDeviceP2PAttr {
hipDevP2PAttrPerformanceRank = 0,
hipDevP2PAttrAccessSupported,
hipDevP2PAttrNativeAtomicSupported,
hipDevP2PAttrHipArrayAccessSupported
} hipDeviceP2PAttr;
typedef struct ihipStream_t* hipStream_t;
#define hipIpcMemLazyEnablePeerAccess 0
#define HIP_IPC_HANDLE_SIZE 64
typedef struct hipIpcMemHandle_st {
char reserved[HIP_IPC_HANDLE_SIZE];
} hipIpcMemHandle_t;
typedef struct hipIpcEventHandle_st {
char reserved[HIP_IPC_HANDLE_SIZE];
} hipIpcEventHandle_t;
typedef struct ihipModule_t* hipModule_t;
typedef struct ihipModuleSymbol_t* hipFunction_t;
typedef struct hipFuncAttributes {
int binaryVersion;
int cacheModeCA;
size_t constSizeBytes;
size_t localSizeBytes;
int maxDynamicSharedSizeBytes;
int maxThreadsPerBlock;
int numRegs;
int preferredShmemCarveout;
int ptxVersion;
size_t sharedSizeBytes;
} hipFuncAttributes;
typedef struct ihipEvent_t* hipEvent_t;
/*
* @brief hipDeviceAttribute_t
* @enum
* @ingroup Enumerations
*/
typedef enum hipDeviceAttribute_t {
hipDeviceAttributeMaxThreadsPerBlock, ///< Maximum number of threads per block.
hipDeviceAttributeMaxBlockDimX, ///< Maximum x-dimension of a block.
hipDeviceAttributeMaxBlockDimY, ///< Maximum y-dimension of a block.
hipDeviceAttributeMaxBlockDimZ, ///< Maximum z-dimension of a block.
hipDeviceAttributeMaxGridDimX, ///< Maximum x-dimension of a grid.
hipDeviceAttributeMaxGridDimY, ///< Maximum y-dimension of a grid.
hipDeviceAttributeMaxGridDimZ, ///< Maximum z-dimension of a grid.
hipDeviceAttributeMaxSharedMemoryPerBlock, ///< Maximum shared memory available per block in
///< bytes.
hipDeviceAttributeTotalConstantMemory, ///< Constant memory size in bytes.
hipDeviceAttributeWarpSize, ///< Warp size in threads.
hipDeviceAttributeMaxRegistersPerBlock, ///< Maximum number of 32-bit registers available to a
///< thread block. This number is shared by all thread
///< blocks simultaneously resident on a
///< multiprocessor.
hipDeviceAttributeClockRate, ///< Peak clock frequency in kilohertz.
hipDeviceAttributeMemoryClockRate, ///< Peak memory clock frequency in kilohertz.
hipDeviceAttributeMemoryBusWidth, ///< Global memory bus width in bits.
hipDeviceAttributeMultiprocessorCount, ///< Number of multiprocessors on the device.
hipDeviceAttributeComputeMode, ///< Compute mode that device is currently in.
hipDeviceAttributeL2CacheSize, ///< Size of L2 cache in bytes. 0 if the device doesn't have L2
///< cache.
hipDeviceAttributeMaxThreadsPerMultiProcessor, ///< Maximum resident threads per
///< multiprocessor.
hipDeviceAttributeComputeCapabilityMajor, ///< Major compute capability version number.
hipDeviceAttributeComputeCapabilityMinor, ///< Minor compute capability version number.
hipDeviceAttributeConcurrentKernels, ///< Device can possibly execute multiple kernels
///< concurrently.
hipDeviceAttributePciBusId, ///< PCI Bus ID.
hipDeviceAttributePciDeviceId, ///< PCI Device ID.
hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, ///< Maximum Shared Memory Per
///< Multiprocessor.
hipDeviceAttributeIsMultiGpuBoard, ///< Multiple GPU devices.
hipDeviceAttributeIntegrated, ///< iGPU
hipDeviceAttributeCooperativeLaunch, ///< Support cooperative launch
hipDeviceAttributeCooperativeMultiDeviceLaunch, ///< Support cooperative launch on multiple devices
hipDeviceAttributeMaxTexture1DWidth, ///< Maximum number of elements in 1D images
hipDeviceAttributeMaxTexture2DWidth, ///< Maximum dimension width of 2D images in image elements
hipDeviceAttributeMaxTexture2DHeight, ///< Maximum dimension height of 2D images in image elements
hipDeviceAttributeMaxTexture3DWidth, ///< Maximum dimension width of 3D images in image elements
hipDeviceAttributeMaxTexture3DHeight, ///< Maximum dimensions height of 3D images in image elements
hipDeviceAttributeMaxTexture3DDepth, ///< Maximum dimensions depth of 3D images in image elements
hipDeviceAttributeHdpMemFlushCntl, ///< Address of the HDP_MEM_COHERENCY_FLUSH_CNTL register
hipDeviceAttributeHdpRegFlushCntl, ///< Address of the HDP_REG_COHERENCY_FLUSH_CNTL register
hipDeviceAttributeMaxPitch, ///< Maximum pitch in bytes allowed by memory copies
hipDeviceAttributeTextureAlignment, ///<Alignment requirement for textures
hipDeviceAttributeTexturePitchAlignment, ///<Pitch alignment requirement for 2D texture references bound to pitched memory;
hipDeviceAttributeKernelExecTimeout, ///<Run time limit for kernels executed on the device
hipDeviceAttributeCanMapHostMemory, ///<Device can map host memory into device address space
hipDeviceAttributeEccEnabled, ///<Device has ECC support enabled
hipDeviceAttributeCooperativeMultiDeviceUnmatchedFunc, ///< Supports cooperative launch on multiple
///devices with unmatched functions
hipDeviceAttributeCooperativeMultiDeviceUnmatchedGridDim, ///< Supports cooperative launch on multiple
///devices with unmatched grid dimensions
hipDeviceAttributeCooperativeMultiDeviceUnmatchedBlockDim, ///< Supports cooperative launch on multiple
///devices with unmatched block dimensions
hipDeviceAttributeCooperativeMultiDeviceUnmatchedSharedMem, ///< Supports cooperative launch on multiple
///devices with unmatched shared memories
hipDeviceAttributeAsicRevision, ///< Revision of the GPU in this device
hipDeviceAttributeManagedMemory, ///< Device supports allocating managed memory on this system
hipDeviceAttributeDirectManagedMemAccessFromHost, ///< Host can directly access managed memory on
/// the device without migration
hipDeviceAttributeConcurrentManagedAccess, ///< Device can coherently access managed memory
/// concurrently with the CPU
hipDeviceAttributePageableMemoryAccess, ///< Device supports coherently accessing pageable memory
/// without calling hipHostRegister on it
hipDeviceAttributePageableMemoryAccessUsesHostPageTables, ///< Device accesses pageable memory via
/// the host's page tables
hipDeviceAttributeCanUseStreamWaitValue ///< '1' if Device supports hipStreamWaitValue32() and
///< hipStreamWaitValue64() , '0' otherwise.
} hipDeviceAttribute_t;
typedef void* hipDeviceptr_t;
/*
* @brief hipJitOption
* @enum
* @ingroup Enumerations
*/
typedef enum hipJitOption {
hipJitOptionMaxRegisters = 0,
hipJitOptionThreadsPerBlock,
hipJitOptionWallTime,
hipJitOptionInfoLogBuffer,
hipJitOptionInfoLogBufferSizeBytes,
hipJitOptionErrorLogBuffer,
hipJitOptionErrorLogBufferSizeBytes,
hipJitOptionOptimizationLevel,
hipJitOptionTargetFromContext,
hipJitOptionTarget,
hipJitOptionFallbackStrategy,
hipJitOptionGenerateDebugInfo,
hipJitOptionLogVerbose,
hipJitOptionGenerateLineInfo,
hipJitOptionCacheMode,
hipJitOptionSm3xOpt,
hipJitOptionFastCompile,
hipJitOptionNumOptions
} hipJitOption;
/**
* @warning On AMD devices and some Nvidia devices, these hints and controls are ignored.
*/
typedef enum hipFuncAttribute {
hipFuncAttributeMaxDynamicSharedMemorySize = 8,
hipFuncAttributePreferredSharedMemoryCarveout = 9,
hipFuncAttributeMax
} hipFuncAttribute;
/**
* @warning On AMD devices and some Nvidia devices, these hints and controls are ignored.
*/
typedef enum hipFuncCache_t {
hipFuncCachePreferNone, ///< no preference for shared memory or L1 (default)
hipFuncCachePreferShared, ///< prefer larger shared memory and smaller L1 cache
hipFuncCachePreferL1, ///< prefer larger L1 cache and smaller shared memory
hipFuncCachePreferEqual, ///< prefer equal size L1 cache and shared memory
} hipFuncCache_t;
#define HIP_LAUNCH_PARAM_BUFFER_POINTER ((void*)0x01)
#define HIP_LAUNCH_PARAM_BUFFER_SIZE ((void*)0x02)
#define HIP_LAUNCH_PARAM_END ((void*)0x03)
#endif


@@ -1,57 +0,0 @@
#pragma once
#ifndef _TRITON_TOOLS_BENCH_H_
#define _TRITON_TOOLS_BENCH_H_
#include "triton/driver/device.h"
#include "triton/driver/stream.h"
#include <algorithm>
#include <chrono>
#include <functional>
namespace triton {
namespace tools {
class timer {
typedef std::chrono::high_resolution_clock high_resolution_clock;
typedef std::chrono::nanoseconds nanoseconds;
public:
explicit timer(bool run = false) {
if (run)
start();
}
void start() { _start = high_resolution_clock::now(); }
nanoseconds get() const {
return std::chrono::duration_cast<nanoseconds>(
high_resolution_clock::now() - _start);
}
private:
high_resolution_clock::time_point _start;
};
inline double bench(std::function<void()> const &op, driver::stream *stream,
size_t warmup = 10, size_t repeat = 200) {
timer tmr;
for (size_t i = 0; i < warmup; i++)
op();
stream->synchronize();
tmr.start();
for (size_t i = 0; i < repeat; i++) {
op();
}
stream->synchronize();
// mean nanoseconds per iteration
return (double)tmr.get().count() / repeat;
}
} // namespace tools
} // namespace triton
#endif
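
An editorial sketch: bench() above is tied to the removed driver::stream, but the measurement shape is plain chrono -- warm up, synchronize, time `repeat` iterations, report the mean. A stream-free version of the same loop:

// bench_sketch.cc -- illustrative only; CPU-side timing, no GPU stream.
#include <chrono>
#include <cstdio>
#include <functional>

static double bench_cpu(const std::function<void()> &op, size_t warmup = 10,
                        size_t repeat = 200) {
  using clock = std::chrono::high_resolution_clock;
  for (size_t i = 0; i < warmup; i++)
    op(); // warm up, as bench() does before synchronizing
  auto start = clock::now();
  for (size_t i = 0; i < repeat; i++)
    op();
  auto ns = std::chrono::duration_cast<std::chrono::nanoseconds>(
                clock::now() - start)
                .count();
  return double(ns) / double(repeat); // mean ns per iteration
}

int main() {
  volatile double sink = 0;
  double t = bench_cpu([&] {
    for (int i = 0; i < 1000; i++)
      sink = sink + i;
  });
  std::printf("%.1f ns/iter\n", t);
  return 0;
}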


@@ -1,68 +0,0 @@
#pragma once
#ifndef _TRITON_TOOLS_THREAD_GRAPH_H_
#define _TRITON_TOOLS_THREAD_GRAPH_H_
#include <iostream>
#include <map>
#include <set>
#include <vector>
namespace triton {
namespace tools {
template <class node_t> class graph {
typedef std::map<node_t, std::set<node_t>> edges_t;
public:
typedef std::map<size_t, std::vector<node_t>> cmap_t;
typedef std::map<node_t, size_t> nmap_t;
private:
void connected_components_impl(node_t x, std::set<node_t> &nodes,
nmap_t *nmap, cmap_t *cmap, int id) const {
if (nmap)
(*nmap)[x] = id;
if (cmap)
(*cmap)[id].push_back(x);
if (nodes.find(x) != nodes.end()) {
nodes.erase(x);
for (const node_t &y : edges_.at(x))
connected_components_impl(y, nodes, nmap, cmap, id);
}
}
public:
void connected_components(cmap_t *cmap, nmap_t *nmap) const {
if (cmap)
cmap->clear();
if (nmap)
nmap->clear();
std::set<node_t> nodes = nodes_;
unsigned id = 0;
while (!nodes.empty()) {
connected_components_impl(*nodes.begin(), nodes, nmap, cmap, id++);
}
}
void add_edge(node_t x, node_t y) {
nodes_.insert(x);
nodes_.insert(y);
edges_[x].insert(y);
edges_[y].insert(x);
}
void clear() {
nodes_.clear();
edges_.clear();
}
private:
std::set<node_t> nodes_;
edges_t edges_;
};
} // namespace tools
} // namespace triton
#endif
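
An editorial usage sketch for the graph above (assuming the header is included); any node type with a strict weak ordering works:

// graph_sketch.cc -- illustrative only.
#include <cstdio>

int main() {
  triton::tools::graph<int> g;
  g.add_edge(1, 2);
  g.add_edge(2, 3); // {1, 2, 3} forms one component
  g.add_edge(7, 8); // {7, 8} forms another
  triton::tools::graph<int>::cmap_t cmap; // component id -> nodes
  triton::tools::graph<int>::nmap_t nmap; // node -> component id
  g.connected_components(&cmap, &nmap);
  std::printf("components: %zu\n", cmap.size()); // prints 2
  return 0;
}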


@@ -1,172 +0,0 @@
/*
Copyright (c) 2011, Micael Hildenborg
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of Micael Hildenborg nor the
names of its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY Micael Hildenborg ''AS IS'' AND ANY
EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL Micael Hildenborg BE LIABLE FOR ANY
DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
/*
Contributors:
Gustav
Several members in the gamedev.se forum.
Gregory Petrosyan
*/
#ifndef _TRITON_TOOLS_SHA1_HPP_
#define _TRITON_TOOLS_SHA1_HPP_
namespace sha1 {
namespace // local
{
// Rotate an integer value to left.
inline unsigned int rol(const unsigned int value, const unsigned int steps) {
return ((value << steps) | (value >> (32 - steps)));
}
// Sets the first 16 integers in the buffer to zero.
// Used for clearing the W buffer.
inline void clearWBuffert(unsigned int *buffert) {
for (int pos = 16; --pos >= 0;) {
buffert[pos] = 0;
}
}
inline void innerHash(unsigned int *result, unsigned int *w) {
unsigned int a = result[0];
unsigned int b = result[1];
unsigned int c = result[2];
unsigned int d = result[3];
unsigned int e = result[4];
int round = 0;
#define sha1macro(func, val) \
{ \
const unsigned int t = rol(a, 5) + (func) + e + val + w[round]; \
e = d; \
d = c; \
c = rol(b, 30); \
b = a; \
a = t; \
}
while (round < 16) {
sha1macro((b & c) | (~b & d), 0x5a827999);
++round;
}
while (round < 20) {
w[round] =
rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1);
sha1macro((b & c) | (~b & d), 0x5a827999);
++round;
}
while (round < 40) {
w[round] =
rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1);
sha1macro(b ^ c ^ d, 0x6ed9eba1);
++round;
}
while (round < 60) {
w[round] =
rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1);
sha1macro((b & c) | (b & d) | (c & d), 0x8f1bbcdc);
++round;
}
while (round < 80) {
w[round] =
rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1);
sha1macro(b ^ c ^ d, 0xca62c1d6);
++round;
}
#undef sha1macro
result[0] += a;
result[1] += b;
result[2] += c;
result[3] += d;
result[4] += e;
}
} // namespace
inline void calc(const void *src, const int bytelength, unsigned char *hash) {
// Init the result array.
unsigned int result[5] = {0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476,
0xc3d2e1f0};
// Cast the void src pointer to be the byte array we can work with.
const unsigned char *sarray = (const unsigned char *)src;
// The reusable round buffer
unsigned int w[80];
// Loop through all complete 64-byte blocks.
const int endOfFullBlocks = bytelength - 64;
int endCurrentBlock;
int currentBlock = 0;
while (currentBlock <= endOfFullBlocks) {
endCurrentBlock = currentBlock + 64;
// Init the round buffer with the 64 byte block data.
for (int roundPos = 0; currentBlock < endCurrentBlock; currentBlock += 4) {
// This line will swap endian on big endian and keep endian on little
// endian.
w[roundPos++] = (unsigned int)sarray[currentBlock + 3] |
(((unsigned int)sarray[currentBlock + 2]) << 8) |
(((unsigned int)sarray[currentBlock + 1]) << 16) |
(((unsigned int)sarray[currentBlock]) << 24);
}
innerHash(result, w);
}
// Handle the last, possibly incomplete, 64-byte block.
endCurrentBlock = bytelength - currentBlock;
clearWBuffert(w);
int lastBlockBytes = 0;
for (; lastBlockBytes < endCurrentBlock; ++lastBlockBytes) {
w[lastBlockBytes >> 2] |=
(unsigned int)sarray[lastBlockBytes + currentBlock]
<< ((3 - (lastBlockBytes & 3)) << 3);
}
w[lastBlockBytes >> 2] |= 0x80 << ((3 - (lastBlockBytes & 3)) << 3);
if (endCurrentBlock >= 56) {
innerHash(result, w);
clearWBuffert(w);
}
w[15] = bytelength << 3;
innerHash(result, w);
// Store hash in result pointer, and make sure we get it in the correct order
// on both endian models.
for (int hashByte = 20; --hashByte >= 0;) {
hash[hashByte] =
(result[hashByte >> 2] >> (((3 - hashByte) & 0x3) << 3)) & 0xff;
}
}
inline void toHexString(const unsigned char *hash, char *hexstring) {
const char hexDigits[] = {"0123456789abcdef"};
for (int hashByte = 20; --hashByte >= 0;) {
hexstring[hashByte << 1] = hexDigits[(hash[hashByte] >> 4) & 0xf];
hexstring[(hashByte << 1) + 1] = hexDigits[hash[hashByte] & 0xf];
}
hexstring[40] = 0;
}
} // namespace sha1
#endif
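
An editorial usage sketch (assuming the header is included), checked against the well-known SHA-1 test vector for "abc":

// sha1_sketch.cc -- illustrative only.
#include <cstdio>
#include <cstring>

int main() {
  const char *msg = "abc";
  unsigned char hash[20]; // SHA-1 digests are 20 bytes
  char hex[41];           // 40 hex digits + NUL
  sha1::calc(msg, (int)std::strlen(msg), hash);
  sha1::toHexString(hash, hex);
  std::printf("%s\n", hex); // a9993e364706816aba3e25717850c26c9cd0d89d
  return 0;
}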


@@ -1,42 +0,0 @@
#ifndef TRITON_TOOLS_SYS_EXEC_HPP
#define TRITON_TOOLS_SYS_EXEC_HPP
#include <cstdio>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
namespace triton {
namespace tools {
#ifdef _WIN32
#define popen _popen
#define pclose _pclose
#endif
#ifndef WEXITSTATUS
#define WEXITSTATUS(stat_val) ((unsigned)(stat_val)&255)
#endif
inline int exec(const std::string &cmd, std::string &result) {
char buffer[128];
FILE *pipe = popen(cmd.c_str(), "r");
if (!pipe)
return 0;
result.clear();
try {
while (fgets(buffer, sizeof buffer, pipe) != NULL)
result += buffer;
} catch (...) {
pclose(pipe);
return 0;
}
int status = pclose(pipe);
return WEXITSTATUS(status);
}
} // namespace tools
} // namespace triton
#endif
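
An editorial usage sketch (assuming the header is included). One quirk worth noting: exec() returns WEXITSTATUS on success but 0 when popen fails, so a 0 return is ambiguous between "command exited 0" and "pipe could not be opened":

// exec_sketch.cc -- illustrative only.
#include <cstdio>
#include <string>

int main() {
  std::string out;
  int status = triton::tools::exec("echo hello", out);
  std::printf("status=%d output=%s", status, out.c_str()); // status=0 output=hello
  return 0;
}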


@@ -1,70 +0,0 @@
/*
* Copyright (c) 2015, PHILIPPE TILLET. All rights reserved.
*
* This file is part of ISAAC.
*
* ISAAC is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
* MA 02110-1301 USA
*/
#ifndef TDL_TOOLS_SYS_MKDIR_HPP
#define TDL_TOOLS_SYS_MKDIR_HPP
#include <cstdlib>
#include <cstring>
#include <errno.h>
#include <string>
#include <sys/stat.h>
#if defined(_WIN32)
#include <direct.h>
#endif
namespace triton {
namespace tools {
inline int mkdir(std::string const &path) {
#if defined(_WIN32)
return _mkdir(path.c_str());
#else
return ::mkdir(path.c_str(), 0777);
#endif
}
inline int mkpath(std::string const &path) {
int status = 0;
size_t pp = 0;
size_t sp;
while ((sp = path.find('/', pp)) != std::string::npos) {
if (sp != pp) {
status = mkdir(path.substr(0, sp));
}
pp = sp + 1;
}
return (status == 0 || errno == EEXIST) ? 0 : -1;
}
inline int mtime(std::string const &path) {
struct stat st;
if (stat(path.c_str(), &st) != 0)
return 0;
return st.st_mtime;
}
} // namespace tools
} // namespace triton
#endif
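
An editorial usage sketch (assuming the header is included). Because mkpath() only creates the component before each '/', the trailing slash below is what makes the last directory get created:

// mkdir_sketch.cc -- illustrative only; the path is an arbitrary example.
#include <cstdio>

int main() {
  int rc = triton::tools::mkpath("/tmp/triton_demo/cache/");
  std::printf("mkpath rc=%d\n", rc); // 0 on success or if it already existed
  return 0;
}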


@@ -1,81 +0,0 @@
#pragma once
#ifndef _TRITON_TOOLS_THREAD_POOL_H_
#define _TRITON_TOOLS_THREAD_POOL_H_
#include <condition_variable>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <thread>
#include <vector>
class ThreadPool {
public:
ThreadPool(size_t threads) : stop(false) {
for (size_t i = 0; i < threads; ++i)
workers.emplace_back([this] {
for (;;) {
std::function<void()> task;
{
std::unique_lock<std::mutex> lock(this->queue_mutex);
this->condition.wait(
lock, [this] { return this->stop || !this->tasks.empty(); });
if (this->stop && this->tasks.empty())
return;
task = std::move(this->tasks.front());
this->tasks.pop();
}
task();
}
});
}
template <class F, class... Args>
auto enqueue(F &&f, Args &&...args)
-> std::future<typename std::result_of<F(Args...)>::type> {
using return_type = typename std::result_of<F(Args...)>::type;
auto task = std::make_shared<std::packaged_task<return_type()>>(
std::bind(std::forward<F>(f), std::forward<Args>(args)...));
std::future<return_type> res = task->get_future();
{
std::unique_lock<std::mutex> lock(queue_mutex);
// don't allow enqueueing after stopping the pool
if (stop)
throw std::runtime_error("enqueue on stopped ThreadPool");
tasks.emplace([task]() { (*task)(); });
}
condition.notify_one();
return res;
}
~ThreadPool() {
{
std::unique_lock<std::mutex> lock(queue_mutex);
stop = true;
}
condition.notify_all();
for (std::thread &worker : workers)
worker.join();
}
private:
// need to keep track of threads so we can join them
std::vector<std::thread> workers;
// the task queue
std::queue<std::function<void()>> tasks;
// synchronization
std::mutex queue_mutex;
std::condition_variable condition;
bool stop;
};
#endif
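
An editorial usage sketch (assuming the header is included; compile with -pthread):

// thread_pool_sketch.cc -- illustrative only.
#include <cstdio>

int main() {
  ThreadPool pool(4); // spawns 4 worker threads
  auto fut = pool.enqueue([](int a, int b) { return a + b; }, 2, 3);
  std::printf("2 + 3 = %d\n", fut.get()); // prints 5
  return 0;
} // ~ThreadPool() sets stop, wakes all workers, and joins them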


@@ -1,5 +1,4 @@
# add_subdirectory(codegen)
add_subdirectory(driver)
add_subdirectory(Analysis)
add_subdirectory(Conversion)
add_subdirectory(Dialect)


@@ -13,7 +13,6 @@
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/driver/llvm.h"
#include "triton/tools/sys/getenv.hpp"
#include "llvm/IR/Constants.h"
@@ -99,7 +98,6 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
}
// Initialize LLVM targets.
::triton::driver::init_llvm();
mlir::ExecutionEngine::setupTargetTriple(llvmModule.get());
auto optPipeline = mlir::makeOptimizingTransformer(


@@ -11,31 +11,129 @@
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
#include "triton/driver/dispatch.h"
#include "triton/driver/llvm.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IRPrintingPasses.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include <regex>
namespace triton {
void getCuCCAndVersionFromDevice(uint64_t device, int *cc, int *version,
std::string *ptxasPath) {
CUdevice dev = (CUdevice)device;
size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
*cc = major * 10 + minor;
*ptxasPath = driver::path_to_ptxas(*version); // assign version
extern "C" {
int set_curterm(char *nterm) { return 0; }
int del_curterm(char *nterm) { return 0; }
int tigetnum(char *capname) { return 0; }
int setupterm(char *term, int fildes, int *errret) { return 0; }
}
std::tuple<std::string, size_t, int, std::string>
translateTritonGPUToPTX(mlir::ModuleOp module, uint64_t device) {
int cc;
int version;
std::string ptxasPath;
getCuCCAndVersionFromDevice(device, &cc, &version, &ptxasPath);
static void init_llvm() {
LLVMInitializeNVPTXTargetInfo();
LLVMInitializeNVPTXTarget();
LLVMInitializeNVPTXTargetMC();
LLVMInitializeNVPTXAsmPrinter();
}
llvm::LLVMContext ctx;
auto llModule = mlir::triton::translateTritonGPUToLLVMIR(&ctx, module);
auto ptxCode = driver::llir_to_ptx(llModule.get(), cc, version);
return std::make_tuple(ptxCode, cc, version, ptxasPath);
static bool find_and_replace(std::string &str, const std::string &begin,
const std::string &end,
const std::string &target) {
size_t start_replace = str.find(begin);
if (start_replace == std::string::npos)
return false;
size_t end_replace = str.find(end, start_replace);
if (end_replace == std::string::npos)
return false;
str.replace(start_replace, end_replace + 1 - start_replace, target);
return true;
}
static std::string llir_to_ptx(llvm::Module *module, int capability, int ptx) {
// LLVM version in use may not officially support target hardware
int max_nvvm_cc = 75;
int max_nvvm_ptx = 74;
// options
auto options = llvm::cl::getRegisteredOptions();
auto *short_ptr =
static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
assert(short_ptr);
short_ptr->setValue(true);
// compute capability
std::string sm = "sm_" + std::to_string(capability);
// max PTX version
int ptx_major = ptx / 10;
int ptx_minor = ptx % 10;
// create
llvm::SmallVector<char, 0> buffer;
std::string triple = "nvptx64-nvidia-cuda";
std::string proc = "sm_" + std::to_string(std::min(capability, max_nvvm_cc));
std::string layout = "";
std::string features = "";
// std::string features = "+ptx" + std::to_string(std::min(ptx,
// max_nvvm_ptx));
init_llvm();
// verify and store llvm
llvm::legacy::PassManager pm;
pm.add(llvm::createVerifierPass());
pm.run(*module);
// module->print(llvm::outs(), nullptr);
// create machine
module->setTargetTriple(triple);
std::string error;
auto target =
llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
llvm::TargetOptions opt;
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
opt.NoNaNsFPMath = true;
llvm::TargetMachine *machine = target->createTargetMachine(
module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
llvm::None, llvm::CodeGenOpt::Aggressive);
// set data layout
if (layout.empty())
module->setDataLayout(machine->createDataLayout());
else
module->setDataLayout(layout);
// emit machine code
for (llvm::Function &f : module->functions())
f.addFnAttr(llvm::Attribute::AlwaysInline);
llvm::legacy::PassManager pass;
llvm::raw_svector_ostream stream(buffer);
// emit
machine->addPassesToEmitFile(pass, stream, nullptr,
llvm::CodeGenFileType::CGFT_AssemblyFile);
pass.run(*module);
// post-process
std::string result(buffer.begin(), buffer.end());
find_and_replace(result, ".version", "\n",
".version " + std::to_string(ptx_major) + "." +
std::to_string(ptx_minor) + "\n");
find_and_replace(result, ".target", "\n", ".target " + sm + "\n");
while (find_and_replace(result, "\t// begin inline asm", "\n", ""))
;
while (find_and_replace(result, "\t// end inline asm", "\n", ""))
;
return result;
}
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
auto ptxCode = llir_to_ptx(&module, cc, version);
return ptxCode;
}
} // namespace triton
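// Usage sketch (an assumption, not part of this file): with the driver layer
// gone, a caller pairs the two translation steps explicitly. `cc` is the
// compute capability (e.g. 80) and the last argument is the PTX ISA version
// encoded as major * 10 + minor (e.g. 74 for PTX 7.4):
//
//   llvm::LLVMContext ctx;
//   auto llvmModule = mlir::triton::translateTritonGPUToLLVMIR(&ctx, mlirModule);
//   std::string ptx = triton::translateLLVMIRToPTX(*llvmModule, 80, 74);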


@@ -1,5 +0,0 @@
add_library(TritonDriver
dispatch.cc
error.cc
llvm.cc
)


@@ -1,395 +0,0 @@
/* Copyright 2015-2017 Philippe Tillet
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include "triton/driver/dispatch.h"
namespace triton {
namespace driver {
// Helpers for function definition
#define DEFINE0(init, hlib, ret, fname) \
ret dispatch::fname() { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname); \
} \
void *dispatch::fname##_;
#define DEFINE1(init, hlib, ret, fname, t1) \
ret dispatch::fname(t1 a) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a); \
} \
void *dispatch::fname##_;
#define DEFINE2(init, hlib, ret, fname, t1, t2) \
ret dispatch::fname(t1 a, t2 b) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b); \
} \
void *dispatch::fname##_;
#define DEFINE3(init, hlib, ret, fname, t1, t2, t3) \
ret dispatch::fname(t1 a, t2 b, t3 c) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c); \
} \
void *dispatch::fname##_;
#define DEFINE4(init, hlib, ret, fname, t1, t2, t3, t4) \
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d); \
} \
void *dispatch::fname##_;
#define DEFINE5(init, hlib, ret, fname, t1, t2, t3, t4, t5) \
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
e); \
} \
void *dispatch::fname##_;
#define DEFINE6(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6) \
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
e, f); \
} \
void *dispatch::fname##_;
#define DEFINE7(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7) \
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
e, f, g); \
} \
void *dispatch::fname##_;
#define DEFINE8(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) \
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
e, f, g, h); \
} \
void *dispatch::fname##_;
#define DEFINE9(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) \
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
e, f, g, h, i); \
} \
void *dispatch::fname##_;
#define DEFINE10(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, \
t10) \
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, \
t10 j) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
e, f, g, h, i, j); \
} \
void *dispatch::fname##_;
#define DEFINE11(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, \
t10, t11) \
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, \
t10 j, t11 k) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
e, f, g, h, i, j, k); \
} \
void *dispatch::fname##_;
#define DEFINE13(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, \
t10, t11, t12, t13) \
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, \
t10 j, t11 k, t12 l, t13 m) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
e, f, g, h, i, j, k, l, m); \
} \
void *dispatch::fname##_;
#define DEFINE19(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, \
t10, t11, t12, t13, t14, t15, t16, t17, t18, t19) \
ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, \
t10 j, t11 k, t12 l, t13 m, t14 n, t15 o, t16 p, t17 q, \
t18 r, t19 s) { \
return f_impl<dispatch::init>(hlib, fname, fname##_, #fname, a, b, c, d, \
e, f, g, h, i, j, k, l, m, n, o, p, q, r, \
s); \
} \
void *dispatch::fname##_;
/* ------------------- *
* CUDA
* ------------------- */
bool dispatch::cuinit() {
if (cuda_ == nullptr) {
#ifdef _WIN32
cuda_ = dlopen("cudart64_110.dll", RTLD_LAZY);
#else
cuda_ = dlopen("libcuda.so", RTLD_LAZY);
if (!cuda_)
cuda_ = dlopen("libcuda.so.1", RTLD_LAZY);
#endif
if (!cuda_)
throw std::runtime_error("Could not find `libcuda.so`. Make sure it is "
"in your LD_LIBRARY_PATH.");
}
if (cuda_ == nullptr)
return false;
CUresult (*fptr)(unsigned int);
cuInit_ = dlsym(cuda_, "cuInit");
*reinterpret_cast<void **>(&fptr) = cuInit_;
CUresult res = (*fptr)(0);
check(res);
return true;
}
#define CUDA_DEFINE1(ret, fname, t1) DEFINE1(cuinit, cuda_, ret, fname, t1)
#define CUDA_DEFINE2(ret, fname, t1, t2) \
DEFINE2(cuinit, cuda_, ret, fname, t1, t2)
#define CUDA_DEFINE3(ret, fname, t1, t2, t3) \
DEFINE3(cuinit, cuda_, ret, fname, t1, t2, t3)
#define CUDA_DEFINE4(ret, fname, t1, t2, t3, t4) \
DEFINE4(cuinit, cuda_, ret, fname, t1, t2, t3, t4)
#define CUDA_DEFINE5(ret, fname, t1, t2, t3, t4, t5) \
DEFINE5(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5)
#define CUDA_DEFINE6(ret, fname, t1, t2, t3, t4, t5, t6) \
DEFINE6(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6)
#define CUDA_DEFINE7(ret, fname, t1, t2, t3, t4, t5, t6, t7) \
DEFINE7(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7)
#define CUDA_DEFINE8(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) \
DEFINE8(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8)
#define CUDA_DEFINE9(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) \
DEFINE9(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9)
#define CUDA_DEFINE10(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) \
DEFINE10(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10)
#define CUDA_DEFINE11(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, \
t11) \
DEFINE11(cuinit, cuda_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, \
t11)
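// Illustrative expansion: CUDA_DEFINE2(CUresult, cuDeviceGet, CUdevice *, int)
// produces
//   CUresult dispatch::cuDeviceGet(CUdevice *a, int b) {
//     return f_impl<dispatch::cuinit>(cuda_, cuDeviceGet, cuDeviceGet_,
//                                     "cuDeviceGet", a, b);
//   }
//   void *dispatch::cuDeviceGet_;
// i.e. each wrapper lazily dlopen()s the library on first use, dlsym()s the
// symbol, caches it in the trailing-underscore pointer, and forwards the call.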
// context management
CUDA_DEFINE1(CUresult, cuCtxDestroy_v2, CUcontext)
CUDA_DEFINE3(CUresult, cuCtxCreate_v2, CUcontext *, unsigned int, CUdevice)
CUDA_DEFINE1(CUresult, cuCtxGetDevice, CUdevice *)
CUDA_DEFINE2(CUresult, cuCtxEnablePeerAccess, CUcontext, unsigned int)
CUDA_DEFINE1(CUresult, cuInit, unsigned int)
CUDA_DEFINE1(CUresult, cuDriverGetVersion, int *)
// device management
CUDA_DEFINE2(CUresult, cuDeviceGet, CUdevice *, int)
CUDA_DEFINE3(CUresult, cuDeviceGetName, char *, int, CUdevice)
CUDA_DEFINE3(CUresult, cuDeviceGetPCIBusId, char *, int, CUdevice)
CUDA_DEFINE3(CUresult, cuDeviceGetAttribute, int *, CUdevice_attribute,
CUdevice)
CUDA_DEFINE1(CUresult, cuDeviceGetCount, int *)
// link management
CUDA_DEFINE8(CUresult, cuLinkAddData_v2, CUlinkState, CUjitInputType, void *,
size_t, const char *, unsigned int, CUjit_option *, void **);
CUDA_DEFINE4(CUresult, cuLinkCreate_v2, unsigned int, CUjit_option *, void **,
CUlinkState *);
CUDA_DEFINE1(CUresult, cuLinkDestroy, CUlinkState);
CUDA_DEFINE3(CUresult, cuLinkComplete, CUlinkState, void **, size_t *);
// module management
CUDA_DEFINE4(CUresult, cuModuleGetGlobal_v2, CUdeviceptr *, size_t *, CUmodule,
const char *)
CUDA_DEFINE2(CUresult, cuModuleLoad, CUmodule *, const char *)
CUDA_DEFINE1(CUresult, cuModuleUnload, CUmodule)
CUDA_DEFINE2(CUresult, cuModuleLoadData, CUmodule *, const void *)
CUDA_DEFINE5(CUresult, cuModuleLoadDataEx, CUmodule *, const void *,
unsigned int, CUjit_option *, void **)
CUDA_DEFINE3(CUresult, cuModuleGetFunction, CUfunction *, CUmodule,
const char *)
// stream management
CUDA_DEFINE2(CUresult, cuStreamCreate, CUstream *, unsigned int)
CUDA_DEFINE1(CUresult, cuStreamSynchronize, CUstream)
CUDA_DEFINE1(CUresult, cuStreamDestroy_v2, CUstream)
CUDA_DEFINE2(CUresult, cuStreamGetCtx, CUstream, CUcontext *)
CUDA_DEFINE11(CUresult, cuLaunchKernel, CUfunction, unsigned int, unsigned int,
unsigned int, unsigned int, unsigned int, unsigned int,
unsigned int, CUstream, void **, void **)
// function management
CUDA_DEFINE3(CUresult, cuFuncGetAttribute, int *, CUfunction_attribute,
CUfunction)
CUDA_DEFINE3(CUresult, cuFuncSetAttribute, CUfunction, CUfunction_attribute,
int)
CUDA_DEFINE2(CUresult, cuFuncSetCacheConfig, CUfunction, CUfunc_cache)
// memory management
CUDA_DEFINE3(CUresult, cuMemcpyDtoH_v2, void *, CUdeviceptr, size_t)
CUDA_DEFINE1(CUresult, cuMemFree_v2, CUdeviceptr)
CUDA_DEFINE4(CUresult, cuMemcpyDtoHAsync_v2, void *, CUdeviceptr, size_t,
CUstream)
CUDA_DEFINE4(CUresult, cuMemcpyHtoDAsync_v2, CUdeviceptr, const void *, size_t,
CUstream)
CUDA_DEFINE3(CUresult, cuMemcpyHtoD_v2, CUdeviceptr, const void *, size_t)
CUDA_DEFINE2(CUresult, cuMemAlloc_v2, CUdeviceptr *, size_t)
CUDA_DEFINE3(CUresult, cuPointerGetAttribute, void *, CUpointer_attribute,
CUdeviceptr)
CUDA_DEFINE4(CUresult, cuMemsetD8Async, CUdeviceptr, unsigned char, size_t,
CUstream)
// event management
CUDA_DEFINE2(CUresult, cuEventCreate, CUevent *, unsigned int)
CUDA_DEFINE3(CUresult, cuEventElapsedTime, float *, CUevent, CUevent)
CUDA_DEFINE2(CUresult, cuEventRecord, CUevent, CUstream)
CUDA_DEFINE1(CUresult, cuEventDestroy_v2, CUevent)
/* ------------------- *
* NVML
* ------------------- */
bool dispatch::nvmlinit() {
#ifdef _WIN32
if (nvml_ == nullptr)
nvml_ = dlopen("nvml.dll", RTLD_LAZY);
#else
if (nvml_ == nullptr)
nvml_ = dlopen("libnvidia-ml.so", RTLD_LAZY);
#endif
nvmlReturn_t (*fptr)();
nvmlInit_v2_ = dlsym(nvml_, "nvmlInit_v2");
*reinterpret_cast<void **>(&fptr) = nvmlInit_v2_;
nvmlReturn_t res = (*fptr)();
check(res);
return res;
}
#define NVML_DEFINE0(ret, fname) DEFINE0(nvmlinit, nvml_, ret, fname)
#define NVML_DEFINE1(ret, fname, t1) DEFINE1(nvmlinit, nvml_, ret, fname, t1)
#define NVML_DEFINE2(ret, fname, t1, t2) \
DEFINE2(nvmlinit, nvml_, ret, fname, t1, t2)
#define NVML_DEFINE3(ret, fname, t1, t2, t3) \
DEFINE3(nvmlinit, nvml_, ret, fname, t1, t2, t3)
NVML_DEFINE2(nvmlReturn_t, nvmlDeviceGetHandleByPciBusId_v2, const char *,
nvmlDevice_t *)
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetClockInfo, nvmlDevice_t,
nvmlClockType_t, unsigned int *)
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetMaxClockInfo, nvmlDevice_t,
nvmlClockType_t, unsigned int *)
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceSetApplicationsClocks, nvmlDevice_t,
unsigned int, unsigned int)
/* ------------------- *
* HIP
* ------------------- */
bool dispatch::hipinit() {
if (hip_ == nullptr)
hip_ = dlopen("libamdhip64.so", RTLD_LAZY);
if (hip_ == nullptr)
return false;
hipError_t (*fptr)();
hipInit_ = dlsym(hip_, "hipInit");
*reinterpret_cast<void **>(&fptr) = hipInit_;
hipError_t res = (*fptr)();
check(res);
return res;
}
#define HIP_DEFINE1(ret, fname, t1) DEFINE1(hipinit, hip_, ret, fname, t1)
#define HIP_DEFINE2(ret, fname, t1, t2) \
DEFINE2(hipinit, hip_, ret, fname, t1, t2)
#define HIP_DEFINE3(ret, fname, t1, t2, t3) \
DEFINE3(hipinit, hip_, ret, fname, t1, t2, t3)
#define HIP_DEFINE4(ret, fname, t1, t2, t3, t4) \
DEFINE4(hipinit, hip_, ret, fname, t1, t2, t3, t4)
#define HIP_DEFINE5(ret, fname, t1, t2, t3, t4, t5) \
DEFINE5(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5)
#define HIP_DEFINE6(ret, fname, t1, t2, t3, t4, t5, t6) \
DEFINE6(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6)
#define HIP_DEFINE7(ret, fname, t1, t2, t3, t4, t5, t6, t7) \
DEFINE7(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7)
#define HIP_DEFINE8(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) \
DEFINE8(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8)
#define HIP_DEFINE9(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) \
DEFINE9(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9)
#define HIP_DEFINE10(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) \
DEFINE10(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10)
#define HIP_DEFINE11(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11) \
DEFINE11(hipinit, hip_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, \
t11)
// context management
HIP_DEFINE1(hipError_t, hipCtxDestroy, hipCtx_t)
HIP_DEFINE3(hipError_t, hipCtxCreate, hipCtx_t *, unsigned int, hipDevice_t)
HIP_DEFINE1(hipError_t, hipCtxGetDevice, hipDevice_t *)
HIP_DEFINE1(hipError_t, hipCtxPushCurrent, hipCtx_t)
HIP_DEFINE1(hipError_t, hipCtxPopCurrent, hipCtx_t *)
HIP_DEFINE2(hipError_t, hipCtxEnablePeerAccess, hipCtx_t, unsigned int)
HIP_DEFINE1(hipError_t, hipInit, unsigned int)
HIP_DEFINE1(hipError_t, hipDriverGetVersion, int *)
// device management
HIP_DEFINE2(hipError_t, hipGetDevice, hipDevice_t *, int)
HIP_DEFINE3(hipError_t, hipDeviceGetName, char *, int, hipDevice_t)
HIP_DEFINE3(hipError_t, hipDeviceGetPCIBusId, char *, int, hipDevice_t)
HIP_DEFINE3(hipError_t, hipDeviceGetAttribute, int *, hipDeviceAttribute_t,
hipDevice_t)
HIP_DEFINE1(hipError_t, hipGetDeviceCount, int *)
// module management
HIP_DEFINE4(hipError_t, hipModuleGetGlobal, hipDeviceptr_t *, size_t *,
hipModule_t, const char *)
HIP_DEFINE2(hipError_t, hipModuleLoad, hipModule_t *, const char *)
HIP_DEFINE1(hipError_t, hipModuleUnload, hipModule_t)
HIP_DEFINE2(hipError_t, hipModuleLoadData, hipModule_t *, const void *)
HIP_DEFINE5(hipError_t, hipModuleLoadDataEx, hipModule_t *, const void *,
unsigned int, hipJitOption *, void **)
HIP_DEFINE3(hipError_t, hipModuleGetFunction, hipFunction_t *, hipModule_t,
const char *)
// stream management
HIP_DEFINE2(hipError_t, hipStreamCreate, hipStream_t *, unsigned int)
HIP_DEFINE1(hipError_t, hipStreamSynchronize, hipStream_t)
HIP_DEFINE1(hipError_t, hipStreamDestroy, hipStream_t)
HIP_DEFINE11(hipError_t, hipModuleLaunchKernel, hipFunction_t, unsigned int,
unsigned int, unsigned int, unsigned int, unsigned int,
unsigned int, unsigned int, hipStream_t, void **, void **)
// function management
HIP_DEFINE2(hipError_t, hipFuncGetAttributes, hipFuncAttributes *, void *)
HIP_DEFINE2(hipError_t, hipFuncSetCacheConfig, hipFunction_t, hipFuncCache_t)
// memory management
HIP_DEFINE3(hipError_t, hipMemcpyDtoH, void *, hipDeviceptr_t, size_t)
HIP_DEFINE1(hipError_t, hipFree, hipDeviceptr_t)
HIP_DEFINE4(hipError_t, hipMemcpyDtoHAsync, void *, hipDeviceptr_t, size_t,
hipStream_t)
HIP_DEFINE4(hipError_t, hipMemcpyHtoDAsync, hipDeviceptr_t, const void *,
size_t, hipStream_t)
HIP_DEFINE3(hipError_t, hipMemcpyHtoD, hipDeviceptr_t, const void *, size_t)
HIP_DEFINE2(hipError_t, hipMalloc, hipDeviceptr_t *, size_t)
HIP_DEFINE3(hipError_t, hipPointerGetAttribute, void *, CUpointer_attribute,
hipDeviceptr_t)
HIP_DEFINE4(hipError_t, hipMemsetD8Async, hipDeviceptr_t, unsigned char, size_t,
hipStream_t)
// event management
HIP_DEFINE2(hipError_t, hipEventCreate, hipEvent_t *, unsigned int)
HIP_DEFINE3(hipError_t, hipEventElapsedTime, float *, hipEvent_t, hipEvent_t)
HIP_DEFINE2(hipError_t, hipEventRecord, hipEvent_t, hipStream_t)
HIP_DEFINE1(hipError_t, hipEventDestroy, hipEvent_t)
/* ------------------- *
* COMMON
* ------------------- */
// Release
void dispatch::release() {
if (cuda_) {
dlclose(cuda_);
cuda_ = nullptr;
}
}
void *dispatch::cuda_;
void *dispatch::nvml_;
void *dispatch::nvmlInit_v2_;
void *dispatch::hip_;
} // namespace driver
} // namespace triton


@@ -1,270 +0,0 @@
/* Copyright 2015-2017 Philippe Tillet
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include "triton/driver/error.h"
namespace triton {
namespace driver {
void check(CUresult err) {
using namespace exception::cuda;
switch (err) {
case CUDA_SUCCESS:
break;
case CUDA_ERROR_INVALID_VALUE:
throw invalid_value();
case CUDA_ERROR_OUT_OF_MEMORY:
throw out_of_memory();
case CUDA_ERROR_NOT_INITIALIZED:
throw not_initialized();
case CUDA_ERROR_DEINITIALIZED:
throw deinitialized();
case CUDA_ERROR_PROFILER_DISABLED:
throw profiler_disabled();
case CUDA_ERROR_PROFILER_NOT_INITIALIZED:
throw profiler_not_initialized();
case CUDA_ERROR_PROFILER_ALREADY_STARTED:
throw profiler_already_started();
case CUDA_ERROR_PROFILER_ALREADY_STOPPED:
throw profiler_already_stopped();
case CUDA_ERROR_NO_DEVICE:
throw no_device();
case CUDA_ERROR_INVALID_DEVICE:
throw invalid_device();
case CUDA_ERROR_INVALID_IMAGE:
throw invalid_image();
case CUDA_ERROR_INVALID_CONTEXT:
throw invalid_context();
case CUDA_ERROR_CONTEXT_ALREADY_CURRENT:
throw context_already_current();
case CUDA_ERROR_MAP_FAILED:
throw map_failed();
case CUDA_ERROR_UNMAP_FAILED:
throw unmap_failed();
case CUDA_ERROR_ARRAY_IS_MAPPED:
throw array_is_mapped();
case CUDA_ERROR_ALREADY_MAPPED:
throw already_mapped();
case CUDA_ERROR_NO_BINARY_FOR_GPU:
throw no_binary_for_gpu();
case CUDA_ERROR_ALREADY_ACQUIRED:
throw already_acquired();
case CUDA_ERROR_NOT_MAPPED:
throw not_mapped();
case CUDA_ERROR_NOT_MAPPED_AS_ARRAY:
throw not_mapped_as_array();
case CUDA_ERROR_NOT_MAPPED_AS_POINTER:
throw not_mapped_as_pointer();
case CUDA_ERROR_ECC_UNCORRECTABLE:
throw ecc_uncorrectable();
case CUDA_ERROR_UNSUPPORTED_LIMIT:
throw unsupported_limit();
case CUDA_ERROR_CONTEXT_ALREADY_IN_USE:
throw context_already_in_use();
case CUDA_ERROR_PEER_ACCESS_UNSUPPORTED:
throw peer_access_unsupported();
case CUDA_ERROR_INVALID_PTX:
throw invalid_ptx();
case CUDA_ERROR_INVALID_GRAPHICS_CONTEXT:
throw invalid_graphics_context();
case CUDA_ERROR_INVALID_SOURCE:
throw invalid_source();
case CUDA_ERROR_FILE_NOT_FOUND:
throw file_not_found();
case CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND:
throw shared_object_symbol_not_found();
case CUDA_ERROR_SHARED_OBJECT_INIT_FAILED:
throw shared_object_init_failed();
case CUDA_ERROR_OPERATING_SYSTEM:
throw operating_system();
case CUDA_ERROR_INVALID_HANDLE:
throw invalid_handle();
case CUDA_ERROR_NOT_FOUND:
throw not_found();
case CUDA_ERROR_NOT_READY:
throw not_ready();
case CUDA_ERROR_ILLEGAL_ADDRESS:
throw illegal_address();
case CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES:
throw launch_out_of_resources();
case CUDA_ERROR_LAUNCH_TIMEOUT:
throw launch_timeout();
case CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING:
throw launch_incompatible_texturing();
case CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED:
throw peer_access_already_enabled();
case CUDA_ERROR_PEER_ACCESS_NOT_ENABLED:
throw peer_access_not_enabled();
case CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE:
throw primary_context_active();
case CUDA_ERROR_CONTEXT_IS_DESTROYED:
throw context_is_destroyed();
case CUDA_ERROR_ASSERT:
throw assert_error();
case CUDA_ERROR_TOO_MANY_PEERS:
throw too_many_peers();
case CUDA_ERROR_HOST_MEMORY_ALREADY_REGISTERED:
throw host_memory_already_registered();
case CUDA_ERROR_HOST_MEMORY_NOT_REGISTERED:
throw host_memory_not_registered();
case CUDA_ERROR_HARDWARE_STACK_ERROR:
throw hardware_stack_error();
case CUDA_ERROR_ILLEGAL_INSTRUCTION:
throw illegal_instruction();
case CUDA_ERROR_MISALIGNED_ADDRESS:
throw misaligned_address();
case CUDA_ERROR_INVALID_ADDRESS_SPACE:
throw invalid_address_space();
case CUDA_ERROR_INVALID_PC:
throw invalid_pc();
case CUDA_ERROR_LAUNCH_FAILED:
throw launch_failed();
case CUDA_ERROR_NOT_PERMITTED:
throw not_permitted();
case CUDA_ERROR_NOT_SUPPORTED:
throw not_supported();
case CUDA_ERROR_UNKNOWN:
throw unknown();
default:
throw unknown();
}
}
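// Typical call site (illustrative): a raw driver call is wrapped so that a
// failing CUresult surfaces as a typed C++ exception, e.g.
//   CUdeviceptr ptr;
//   driver::check(dispatch::cuMemAlloc_v2(&ptr, bytes));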
void check(hipError_t error) {
using namespace exception::hip;
switch (error) {
case hipSuccess:
break;
case hipErrorInvalidValue:
throw invalid_value();
case hipErrorMemoryAllocation:
throw out_of_memory();
case hipErrorNotInitialized:
throw not_initialized();
case hipErrorDeinitialized:
throw deinitialized();
case hipErrorProfilerDisabled:
throw profiler_disabled();
case hipErrorProfilerNotInitialized:
throw profiler_not_initialized();
case hipErrorProfilerAlreadyStarted:
throw profiler_already_started();
case hipErrorProfilerAlreadyStopped:
throw profiler_already_stopped();
case hipErrorNoDevice:
throw no_device();
case hipErrorInvalidSymbol:
throw invalid_symbol();
case hipErrorInvalidDevice:
throw invalid_device();
case hipErrorInvalidImage:
throw invalid_image();
case hipErrorInvalidContext:
throw invalid_context();
case hipErrorContextAlreadyCurrent:
throw context_already_current();
case hipErrorMapFailed:
throw map_failed();
case hipErrorUnmapFailed:
throw unmap_failed();
case hipErrorArrayIsMapped:
throw array_is_mapped();
case hipErrorAlreadyMapped:
throw already_mapped();
case hipErrorNoBinaryForGpu:
throw no_binary_for_gpu();
case hipErrorAlreadyAcquired:
throw already_acquired();
case hipErrorNotMapped:
throw not_mapped();
case hipErrorNotMappedAsArray:
throw not_mapped_as_array();
case hipErrorNotMappedAsPointer:
throw not_mapped_as_pointer();
case hipErrorECCNotCorrectable:
throw ecc_uncorrectable();
case hipErrorUnsupportedLimit:
throw unsupported_limit();
case hipErrorContextAlreadyInUse:
throw context_already_in_use();
case hipErrorPeerAccessUnsupported:
throw peer_access_unsupported();
case hipErrorInvalidKernelFile:
throw invalid_ptx();
case hipErrorInvalidGraphicsContext:
throw invalid_graphics_context();
case hipErrorInvalidSource:
throw invalid_source();
case hipErrorFileNotFound:
throw file_not_found();
case hipErrorSharedObjectSymbolNotFound:
throw shared_object_symbol_not_found();
case hipErrorSharedObjectInitFailed:
throw shared_object_init_failed();
case hipErrorOperatingSystem:
throw operating_system();
case hipErrorInvalidResourceHandle:
throw invalid_handle();
case hipErrorNotFound:
throw not_found();
case hipErrorNotReady:
throw not_ready();
case hipErrorIllegalAddress:
throw illegal_address();
case hipErrorLaunchOutOfResources:
throw launch_out_of_resources();
case hipErrorLaunchTimeOut:
throw launch_timeout();
// case hipErrorLaunchIncompatibleTexturing : throw
// launch_incompatible_texturing();
case hipErrorPeerAccessAlreadyEnabled:
throw peer_access_already_enabled();
case hipErrorPeerAccessNotEnabled:
throw peer_access_not_enabled();
// case hipErrorPrimaryContextActive : throw primary_context_active();
// case hipErrorContextIsDestroyed : throw context_is_destroyed();
case hipErrorAssert:
throw assert_error();
// case hipErrorTooManyPeers : throw too_many_peers();
case hipErrorHostMemoryAlreadyRegistered:
throw host_memory_already_registered();
case hipErrorHostMemoryNotRegistered:
throw host_memory_not_registered();
// case hipErrorHardwareStackError : throw hardware_stack_error();
// case hipErrorIllegalInstruction : throw illegal_instruction();
// case hipErrorMisalignedAddress : throw misaligned_address();
// case hipErrorInvalidAddressSpace : throw invalid_address_space();
// case hipErrorInvalidPc : throw invalid_pc();
case hipErrorLaunchFailure:
throw launch_failed();
// case hipErrorNotPermitted : throw not_permitted();
case hipErrorNotSupported:
throw not_supported();
case hipErrorUnknown:
throw unknown();
default:
throw unknown();
}
}
} // namespace driver
} // namespace triton


@@ -1,392 +0,0 @@
/* Copyright 2015-2017 Philippe Tillet
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include <fstream>
#if defined __has_include
#if __has_include(<unistd.h>)
#include <unistd.h>
#endif
#endif
#include "triton/driver/dispatch.h"
#include "triton/driver/error.h"
#include "triton/driver/llvm.h"
#include "triton/tools/sha1.hpp"
#include "triton/tools/sys/exec.hpp"
#include "triton/tools/sys/getenv.hpp"
#include "triton/tools/sys/mkdir.hpp"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IRPrintingPasses.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include <memory>
#include <regex>
// begin AMD stuff
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/FormattedStream.h"
#include "llvm/Support/Program.h"
#include "llvm/Support/ToolOutputFile.h"
// end AMD stuff
extern "C" {
int set_curterm(char *nterm) { return 0; }
int del_curterm(char *nterm) { return 0; }
int tigetnum(char *capname) { return 0; }
int setupterm(char *term, int fildes, int *errret) { return 0; }
}
namespace triton {
namespace driver {
void init_llvm() {
LLVMInitializeNVPTXTargetInfo();
LLVMInitializeNVPTXTarget();
LLVMInitializeNVPTXTargetMC();
LLVMInitializeNVPTXAsmPrinter();
LLVMInitializeAMDGPUTargetInfo();
LLVMInitializeAMDGPUTarget();
LLVMInitializeAMDGPUTargetMC();
LLVMInitializeAMDGPUAsmPrinter();
}
/* ------------------------ */
// CUDA //
/* ------------------------ */
static bool find_and_replace(std::string &str, const std::string &begin,
const std::string &end,
const std::string &target) {
size_t start_replace = str.find(begin);
if (start_replace == std::string::npos)
return false;
size_t end_replace = str.find(end, start_replace);
if (end_replace == std::string::npos)
return false;
str.replace(start_replace, end_replace + 1 - start_replace, target);
return true;
}
std::string path_to_ptxas(int &version) {
std::vector<std::string> rets;
std::string ret;
// search paths for ptxas
std::vector<std::string> ptxas_prefixes = {"", "/usr/local/cuda/bin/"};
std::string triton_ptxas = tools::getenv("TRITON_PTXAS_PATH");
if (!triton_ptxas.empty())
ptxas_prefixes.insert(ptxas_prefixes.begin(), triton_ptxas);
// see what path for ptxas are valid
std::vector<std::string> working_ptxas;
for (const std::string &prefix : ptxas_prefixes) {
std::string ptxas = prefix + "ptxas";
bool works = tools::exec(ptxas + " --version 2>&1", ret) == 0;
if (works) {
working_ptxas.push_back(ptxas);
rets.push_back(ret);
}
}
// error if no working ptxas was found
if (working_ptxas.empty())
throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, "
"/usr/local/cuda/bin/ or PATH"
" but a working version could not be found.");
std::string ptxas = working_ptxas.front();
// parse version
std::regex version_regex("release (\\d+)\\.(\\d+)");
std::smatch match;
bool found = false;
// currently choosing the first ptxas. Other logics can be implemented in
// future
size_t i = 0;
while (i < rets.size()) {
if (std::regex_search(rets[i], match, version_regex)) {
int major = std::stoi(match[1]);
int minor = std::stoi(match[2]);
version = major * 1000 + minor * 10;
found = true;
break;
}
++i;
}
if (not found) {
throw std::runtime_error("Error in parsing version");
}
return working_ptxas[i];
}
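// Worked example (illustrative): a `ptxas --version` output containing
// "release 11.4" yields version = 11 * 1000 + 4 * 10 = 11040, the same
// encoding vptx() consumes below; TRITON_PTXAS_PATH simply prepends one more
// candidate search prefix.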
int vptx(int version) {
if (version >= 11040)
return 74;
if (version >= 11030)
return 73;
if (version >= 11020)
return 72;
if (version >= 11010)
return 71;
if (version >= 11000)
return 70;
if (version >= 10020)
return 65;
if (version >= 10010)
return 64;
if (version >= 10000)
return 63;
throw std::runtime_error("Triton requires CUDA 10+");
}
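// e.g. vptx(11040) == 74: CUDA 11.4's ptxas understands up to PTX ISA 7.4.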
std::string llir_to_ptx(llvm::Module *module, int cc, int version) {
// LLVM version in use may not officially support target hardware
int max_nvvm_cc = 75;
int max_nvvm_ptx = 74;
// options
auto options = llvm::cl::getRegisteredOptions();
auto *short_ptr =
static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
assert(short_ptr);
short_ptr->setValue(true);
// compute capability
std::string sm = "sm_" + std::to_string(cc);
// max PTX version
int ptx = vptx(version);
int ptx_major = ptx / 10;
int ptx_minor = ptx % 10;
// create
llvm::SmallVector<char, 0> buffer;
std::string triple = "nvptx64-nvidia-cuda";
std::string proc = "sm_" + std::to_string(std::min(cc, max_nvvm_cc));
std::string layout = "";
std::string features = "";
// std::string features = "+ptx" + std::to_string(std::min(ptx,
// max_nvvm_ptx));
init_llvm();
// verify and store llvm
llvm::legacy::PassManager pm;
pm.add(llvm::createVerifierPass());
pm.run(*module);
// module->print(llvm::outs(), nullptr);
// create machine
module->setTargetTriple(triple);
std::string error;
auto target =
llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
llvm::TargetOptions opt;
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
opt.NoNaNsFPMath = true;
llvm::TargetMachine *machine = target->createTargetMachine(
module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
llvm::None, llvm::CodeGenOpt::Aggressive);
// set data layout
if (layout.empty())
module->setDataLayout(machine->createDataLayout());
else
module->setDataLayout(layout);
// emit machine code
for (llvm::Function &f : module->functions())
f.addFnAttr(llvm::Attribute::AlwaysInline);
llvm::legacy::PassManager pass;
llvm::raw_svector_ostream stream(buffer);
// emit
machine->addPassesToEmitFile(pass, stream, nullptr,
llvm::CodeGenFileType::CGFT_AssemblyFile);
pass.run(*module);
// post-process
std::string result(buffer.begin(), buffer.end());
find_and_replace(result, ".version", "\n",
".version " + std::to_string(ptx_major) + "." +
std::to_string(ptx_minor) + "\n");
find_and_replace(result, ".target", "\n", ".target " + sm + "\n");
while (find_and_replace(result, "\t// begin inline asm", "\n", ""))
;
while (find_and_replace(result, "\t// end inline asm", "\n", ""))
;
return result;
}
std::string ptx_to_cubin(const std::string &ptx, const std::string &ptxas,
int cc) {
// compile ptx with ptxas
char _fsrc[L_tmpnam];
char _flog[L_tmpnam];
std::tmpnam(_fsrc);
std::tmpnam(_flog);
std::string fsrc = _fsrc;
std::string flog = _flog;
std::string fbin = fsrc + ".o";
const char *_fbin = fbin.c_str();
std::ofstream ofs(fsrc);
ofs << ptx << std::endl;
ofs.close();
std::string cmd;
int err;
cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc +
" -o " + fsrc + ".o 2> " + flog;
err = system(cmd.c_str());
if (err != 0) {
std::ifstream _log(_flog);
std::string log(std::istreambuf_iterator<char>(_log), {});
unlink(_fsrc);
unlink(_flog);
throw std::runtime_error("Internal Triton PTX codegen error: \n" + log);
}
CUmodule ret;
std::ifstream _cubin(_fbin, std::ios::binary);
std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
_cubin.close();
unlink(_fsrc);
unlink(_flog);
unlink(_fbin);
dispatch::cuModuleLoadData(&ret, cubin.c_str());
return cubin;
}
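// Example invocation (illustrative): for cc = 80 the generated command is
// roughly
//   ptxas -v --gpu-name=sm_80 /tmp/<src> -o /tmp/<src>.o 2> /tmp/<log>
// with the cubin read back from the .o file afterwards.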
/* ------------------------ */
// HIP //
/* ------------------------ */
std::string llir_to_amdgpu(llvm::Module *module, const std::string &_proc) {
init_llvm();
// proc = std::get<0>(GetFeatureStrFromGCNArchName(rocminfo));
// features = std::get<1>(GetFeatureStrFromGCNArchName(rocminfo));
// create
llvm::SmallVector<char, 0> buffer;
std::string triple = "amdgcn-amd-amdhsa";
std::string layout = "";
std::string features;
std::string proc = "gfx908";
// verify and store llvm
llvm::legacy::PassManager pm;
pm.add(llvm::createVerifierPass());
pm.run(*module);
// create machine
module->setTargetTriple(triple);
std::string error;
auto target =
llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
llvm::TargetOptions opt;
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
opt.NoNaNsFPMath = true;
llvm::TargetMachine *machine = target->createTargetMachine(
module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
llvm::None, llvm::CodeGenOpt::Aggressive);
// set data layout
if (layout.empty())
module->setDataLayout(machine->createDataLayout());
else
module->setDataLayout(layout);
// emit machine code
for (llvm::Function &f : module->functions())
f.addFnAttr(llvm::Attribute::AlwaysInline);
llvm::legacy::PassManager pass;
llvm::raw_svector_ostream stream(buffer);
// create dump files
std::string module_name = module->getModuleIdentifier();
std::error_code ec;
// Save GCN ISA binary.
std::string isabin_path =
std::string("/tmp/") + module_name + std::string(".o");
std::unique_ptr<llvm::raw_fd_ostream> isabin_fs(
new llvm::raw_fd_ostream(isabin_path, ec, llvm::sys::fs::OF_Text));
if (ec) {
std::cout << isabin_path << " was not created. error code: " << ec
<< std::endl;
}
// emit
machine->addPassesToEmitFile(pass, *isabin_fs, nullptr,
llvm::CGFT_ObjectFile);
pass.run(*module);
// Save GCN ISA.
std::string amdgcn_path =
std::string("/tmp/") + module_name + std::string(".gcn");
std::string result(buffer.begin(), buffer.end());
std::ofstream amdgcn(amdgcn_path);
amdgcn << result;
amdgcn.close();
// generate HSACO file
std::string hsaco_path =
std::string("/tmp/") + module_name + std::string(".hsaco");
std::string error_message;
int lld_result =
llvm::sys::ExecuteAndWait("/opt/rocm/llvm/bin/ld.lld",
{"/opt/rocm/llvm/bin/ld.lld", "-flavor", "gnu",
"-shared", "-o", hsaco_path, isabin_path},
llvm::None, {}, 0, 0, &error_message);
if (lld_result) {
std::cout << "ld.lld execute fail: " << std::endl;
std::cout << error_message << std::endl;
std::cout << lld_result << std::endl;
}
return hsaco_path;
}
hipModule_t amdgpu_to_hipmodule(const std::string &path) {
// Read HSACO.
std::ifstream hsaco_file(path, std::ios::binary | std::ios::ate);
std::ifstream::pos_type hsaco_file_size = hsaco_file.tellg();
std::vector<unsigned char> hsaco(hsaco_file_size);
hsaco_file.seekg(0, std::ios::beg);
hsaco_file.read(reinterpret_cast<char *>(&hsaco[0]), hsaco_file_size);
hsaco_file.close();
hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes,
hipJitOptionErrorLogBuffer,
hipJitOptionInfoLogBufferSizeBytes,
hipJitOptionInfoLogBuffer, hipJitOptionLogVerbose};
const unsigned int errbufsize = 8192;
const unsigned int logbufsize = 8192;
char _err[errbufsize];
char _log[logbufsize];
void *optval[] = {(void *)(uintptr_t)errbufsize, (void *)_err,
(void *)(uintptr_t)logbufsize, (void *)_log, (void *)1};
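  // Note (descriptive): opt and optval pair up positionally -- error-log
  // size/buffer, info-log size/buffer, and verbose logging enabled via
  // (void *)1.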
hipModule_t ret;
dispatch::hipModuleLoadDataEx(&ret, hsaco.data(), 5, opt, optval);
return ret;
}
} // namespace driver
} // namespace triton


@@ -1,7 +1,4 @@
#include "triton/driver/error.h"
#include "triton/driver/llvm.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Verifier.h"
@@ -10,6 +7,9 @@
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Parser.h"
#include "mlir/Support/FileUtilities.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
@@ -24,10 +24,14 @@
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/SourceMgr.h"
#include <Python.h>
#include <cctype>
#include <fstream>
#include <optional>
#include <pybind11/buffer_info.h>
#include <pybind11/functional.h>
@@ -40,10 +44,6 @@
#include <string>
namespace py = pybind11;
// namespace ir = triton::ir;
namespace drv = triton::driver;
using triton::cuGetInfo;
enum backend_t {
HOST,
@@ -51,306 +51,6 @@ enum backend_t {
ROCM,
};
void cu_enable_peer_access(uint64_t peer_ptr) {
CUcontext context;
drv::dispatch::cuPointerGetAttribute(&context, CU_POINTER_ATTRIBUTE_CONTEXT,
peer_ptr);
try {
drv::dispatch::cuCtxEnablePeerAccess(context, 0);
} catch (drv::exception::cuda::peer_access_already_enabled) {
}
}
void host_enqueue(uint64_t stream, uint64_t kernel, uint64_t grid_0,
uint64_t grid_1, uint64_t grid_2, uint64_t block_0,
uint64_t block_1, uint64_t block_2, void *args_ptr,
size_t args_size, int64_t shared_mem) {
throw std::runtime_error("unsupported");
// auto hst = kernel->module()->hst();
// hst_->futures->reserve(hst_->futures->size() + grid[0]*grid[1]*grid[2]);
// char* params = new char[args_size];
// std::memcpy((void*)params, (void*)args, args_size);
// for(size_t i = 0; i < grid[0]; i++)
// for(size_t j = 0; j < grid[1]; j++)
// for(size_t k = 0; k < grid[2]; k++)
// hst_->futures->emplace_back(hst_->pool->enqueue(hst->fn,
// (char**)params, int32_t(i), int32_t(j), int32_t(k)));
}
void cu_enqueue(uint64_t stream, uint64_t kernel, uint64_t grid_0,
uint64_t grid_1, uint64_t grid_2, uint64_t block_0,
uint64_t block_1, uint64_t block_2, void *args_ptr,
size_t args_size, int64_t shared_mem) {
void *config[] = {CU_LAUNCH_PARAM_BUFFER_POINTER, (void *)args_ptr,
CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
CU_LAUNCH_PARAM_END};
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
block_0, block_1, block_2, shared_mem,
(CUstream)stream, nullptr, config);
}
long pow2_divisor(long N) {
if (N % 16 == 0)
return 16;
if (N % 8 == 0)
return 8;
if (N % 4 == 0)
return 4;
if (N % 2 == 0)
return 2;
return 1;
}
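// e.g. pow2_divisor(24) == 8 and pow2_divisor(1024) == 16; the result feeds
// the "[multipleof(...)]" specialization tag in the cache key below.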
// Returns something like "int16", whether dtype is a torch.dtype or
// triton.language.dtype.
std::string dtype_cache_key_part(const py::object &dtype) {
if (py::hasattr(dtype, "cache_key_part")) {
// Presumed to be a triton.language.dtype.
return std::string(py::str(py::getattr(dtype, "cache_key_part")));
} else {
// Remove 'torch.' prefix from repr of torch.dtype.
py::object repr = py::repr(dtype);
size_t repr_len = PyUnicode_GET_LENGTH(repr.ptr());
const char *repr_ptr = (const char *)PyUnicode_1BYTE_DATA(repr.ptr());
if (repr_len <= 6 || strncmp(repr_ptr, "torch.", 6)) {
throw std::logic_error("invalid dtype: " +
std::string(repr_ptr, repr_len));
}
return std::string(repr_ptr + 6, repr_len - 6);
}
}
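// Example (illustrative): repr(torch.float16) is "torch.float16", so the
// returned key part is "float16"; a triton.language.dtype instead supplies
// its own cache_key_part attribute.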
size_t get_pointer_range_size(uint64_t addr) {
if (addr == 0)
return 0;
size_t size;
drv::dispatch::cuPointerGetAttribute(&size, CU_POINTER_ATTRIBUTE_RANGE_SIZE,
(CUdeviceptr)addr);
return size;
}
// Launch
void parse_args(py::list &args, py::list do_not_specialize,
const std::string &func_key, py::list &arg_names,
std::string &cache_key, std::string &params,
size_t &params_size, py::dict constants, int num_warps,
int num_stages) {
size_t len = PyList_Size(args.ptr());
params.reserve(8 * len); // 8 max bytes by argument
char *params_ptr = &params[0];
cache_key = func_key;
cache_key += "-" + std::to_string(num_warps);
cache_key += "-" + std::to_string(num_stages);
cache_key += "-";
for (int i = 0; i < len; i++) {
cache_key += "_";
py::int_ py_i = py::int_(i);
bool specialize = !do_not_specialize.contains(py_i);
py::object arg = args[i];
auto arg_ptr = arg.ptr();
// argument is `long`
if (PyLong_Check(arg_ptr)) {
int overflow;
long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow);
// values equal to 1 are specialized
if (specialize && (value == 1)) {
cache_key += "1";
continue;
}
// int32, uint32, int64, and uint64 have different kernels
if (!overflow && -0x8000'0000LL <= value && value <= 0x7FFF'FFFFLL) {
cache_key += "int32";
params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
std::memcpy(params_ptr, &value, 4);
params_ptr += 4;
} else if (!overflow && 0x8000'0000LL <= value &&
value <= 0xFFFF'FFFFLL) {
cache_key += "uint32";
params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
std::memcpy(params_ptr, &value, 4);
params_ptr += 4;
} else if (!overflow) {
cache_key += "int64";
params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
std::memcpy(params_ptr, &value, 8);
params_ptr += 8;
} else {
if (PyErr_Occurred()) {
throw std::logic_error("An error occurred?");
}
unsigned long long unsigned_value = PyLong_AsUnsignedLongLong(arg_ptr);
if (PyErr_Occurred()) {
throw std::runtime_error("integer overflow in argument: " +
std::string(py::str(arg)));
}
cache_key += "uint64";
params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
std::memcpy(params_ptr, &unsigned_value, 8);
params_ptr += 8;
}
if (!specialize)
continue;
// values divisible by small powers of 2 are specialized
cache_key += "[multipleof(";
cache_key += std::to_string(pow2_divisor(value));
cache_key += ")]";
continue;
}
// argument is `float`
if (PyFloat_Check(arg_ptr)) {
cache_key += "float32";
float value = PyFloat_AsDouble(arg_ptr);
params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
std::memcpy(params_ptr, &value, 4);
params_ptr += 4;
continue;
}
// argument is `bool`
if (PyBool_Check(arg_ptr)) {
cache_key += "bool";
bool value = arg_ptr == Py_True ? true : false;
std::memcpy(params_ptr, &value, 1);
params_ptr += 1;
continue;
}
// argument is tensor
if (py::hasattr(arg, "data_ptr")) {
py::object data_ptr = arg.attr("data_ptr")();
long value = data_ptr.cast<long>();
params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
// copy param
std::memcpy(params_ptr, &value, 8);
params_ptr += 8;
// update cache key
cache_key += dtype_cache_key_part(arg.attr("dtype"));
cache_key += "*";
cache_key += "[multipleof(";
size_t range_size = get_pointer_range_size(value);
cache_key += std::to_string(
std::min(pow2_divisor(value), pow2_divisor(range_size)));
cache_key += ")]";
continue;
}
// argument is `constexpr`
if (py::hasattr(arg, "value")) {
py::object value = arg.attr("value");
py::object name = arg_names[i];
constants[name] = value;
py::object repr = py::repr(value);
const char *start = (const char *)PyUnicode_1BYTE_DATA(repr.ptr());
size_t len = PyUnicode_GET_LENGTH(repr.ptr());
cache_key += std::string(start, len);
continue;
}
std::string ty_str =
arg.attr("__class__").attr("__name__").cast<std::string>();
if (ty_str == "NoneType") {
cache_key += "None";
continue;
}
std::string err_msg = "Received type '" + ty_str + "' for argument " +
std::to_string(i) + "." +
" Only int, float, bool, torch.Tensor, and "
"triton.language.constexpr are supported.";
throw std::runtime_error(err_msg);
}
params_size = (std::ptrdiff_t)(params_ptr - &params[0]);
}
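// Illustrative cache key (assumed shapes and addresses): for kernel key "add",
// num_warps=4, num_stages=3 and arguments (X: float32 tensor, n=1024) the key
// could be "add-4-3-_float32*[multipleof(16)]_int32[multipleof(16)]" -- the
// tensor contributes dtype + pointer alignment, the int its width + alignment.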
void parse_args(py::list &args, py::list &arg_names, std::string &params,
size_t &params_size, py::dict constants) {
size_t len = PyList_Size(args.ptr());
params.reserve(8 * len); // 8 max bytes by argument
char *params_ptr = params.data();
for (int i = 0; i < len; i++) {
py::object arg = args[i];
auto arg_ptr = arg.ptr();
if (PyLong_Check(arg_ptr)) {
int overflow{};
long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow);
if (!overflow && -0x8000'0000LL <= value && value <= 0x7FFF'FFFFLL) {
params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
std::memcpy(params_ptr, &value, 4);
params_ptr += 4;
} else if (!overflow && 0x8000'0000LL <= value &&
value <= 0xFFFF'FFFFLL) {
params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
std::memcpy(params_ptr, &value, 4);
params_ptr += 4;
} else if (!overflow) {
params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
std::memcpy(params_ptr, &value, 8);
params_ptr += 8;
} else {
if (PyErr_Occurred()) {
throw std::logic_error("An error occurred?");
}
unsigned long long unsigned_value = PyLong_AsUnsignedLongLong(arg_ptr);
if (PyErr_Occurred()) {
throw std::runtime_error("integer overflow in argument: " +
std::string(py::str(arg)));
}
params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
std::memcpy(params_ptr, &unsigned_value, 8);
params_ptr += 8;
}
continue;
}
if (PyFloat_Check(arg_ptr)) {
float value = PyFloat_AsDouble(arg_ptr);
params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
std::memcpy(params_ptr, &value, 4);
params_ptr += 4;
continue;
}
// argument is `bool`
if (PyBool_Check(arg_ptr)) {
bool value = arg_ptr == Py_True ? true : false;
std::memcpy(params_ptr, &value, 1);
params_ptr += 1;
continue;
}
// argument is torch.tensor, get data_ptr as memory address
if (py::hasattr(arg, "data_ptr")) {
py::object data_ptr = arg.attr("data_ptr")();
long value = data_ptr.cast<long>();
params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
// copy param
std::memcpy(params_ptr, &value, 8);
params_ptr += 8;
// update cache key
continue;
}
// argument is `constexpr`
if (py::hasattr(arg, "value")) {
py::object value = arg.attr("value");
py::object name = arg_names[i];
constants[name] = value;
continue;
}
// argument is `LoadedBinary`
if (py::hasattr(arg, "get_sass")) {
// Do nothing, just a placeholder here to indicate validity.
continue;
}
std::string ty_str =
arg.attr("__class__").attr("__name__").cast<std::string>();
std::string err_msg = "Received type '" + ty_str + "' for argument " +
std::to_string(i) + "." +
" Only int, float, bool, torch.Tensor, and "
"triton.language.constexpr are supported.";
throw std::runtime_error(err_msg);
}
params_size = (std::ptrdiff_t)(params_ptr - &params[0]);
}
void init_triton_runtime(py::module &&m) {
// wrap backend_t
py::enum_<backend_t>(m, "backend")
@@ -358,192 +58,8 @@ void init_triton_runtime(py::module &&m) {
.value("CUDA", CUDA)
// .value("ROCM", ROCM)
.export_values();
// enable peer-to-peer
m.def("enable_peer_access", [](backend_t backend, uint64_t peer_ptr) {
if (backend != CUDA)
throw std::runtime_error("P2P only supported on CUDA devices!");
cu_enable_peer_access(peer_ptr);
});
// get range size for the given pointer
m.def("get_pointer_range_size", &get_pointer_range_size);
// cache key
m.def("launch", [](py::list args, py::list do_not_specialize,
const std::string &func_key, py::list &arg_names,
py::object device, py::int_ stream, py::dict bin_cache,
py::int_ num_warps, py::int_ num_stages,
py::function add_to_cache, py::object grid) {
// parse arguments to compute cache key, compile-time constants and packed
// kernel arguments
long _num_warps = PyLong_AsLong(num_warps.ptr());
long _num_stages = PyLong_AsLong(num_stages.ptr());
std::string cache_key;
std::string params;
size_t params_size;
py::dict constants;
parse_args(args, do_not_specialize, func_key, arg_names, cache_key, params,
params_size, constants, _num_warps, _num_stages);
// get cached binary
py::str key(cache_key);
py::bool_ noop = false;
if (!bin_cache.contains(key)) {
noop = add_to_cache(key, args, device, num_warps, num_stages);
}
if (noop)
return (py::object)py::none();
py::object bin = bin_cache[key];
// get grid
py::sequence seq;
if (!PySequence_Check(grid.ptr()))
seq = grid(constants);
else
seq = grid;
int size = seq.size();
int grid_0 = py::cast<int>(seq[0]);
int grid_1 = size < 2 ? 1 : py::cast<int>(seq[1]);
int grid_2 = size < 3 ? 1 : py::cast<int>(seq[2]);
// enqueue
uint64_t kernel = py::cast<uint64_t>(bin.attr("kernel"));
uint64_t shared_mem = py::cast<uint64_t>(bin.attr("shared_mem"));
// actually launch
void *config[] = {CU_LAUNCH_PARAM_BUFFER_POINTER, params.data(),
CU_LAUNCH_PARAM_BUFFER_SIZE, &params_size,
CU_LAUNCH_PARAM_END};
uint64_t _stream = PyLong_AsLong(stream.ptr());
if (grid_0 * grid_1 * grid_2 > 0) {
// release the gil in case the enqueue blocks
// cuda will block if too many ops are enqueued
py::gil_scoped_release allow_threads;
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
_num_warps * 32, 1, 1, shared_mem,
(CUstream)_stream, nullptr, config);
}
return bin;
});
m.def("cc", [](backend_t backend, uint64_t device) -> int {
if (backend == CUDA) {
CUdevice dev = (CUdevice)device;
int major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
int minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
return major * 10 + minor;
}
return -1;
});
m.def("launch_binary", [](py::object binary, py::list args,
py::list do_not_specialize, py::list arg_names,
py::int_ stream, py::int_ num_warps,
py::int_ num_stages, py::object grid) {
long _num_warps = PyLong_AsLong(num_warps.ptr());
long _num_stages = PyLong_AsLong(num_stages.ptr());
// get grid
py::sequence seq;
py::dict constants;
std::string params;
size_t params_size{};
parse_args(args, arg_names, params, params_size, constants);
if (!PySequence_Check(grid.ptr()))
seq = grid(constants);
else
seq = grid;
int size = seq.size();
int grid_0 = py::cast<int>(seq[0]);
int grid_1 = size < 2 ? 1 : py::cast<int>(seq[1]);
int grid_2 = size < 3 ? 1 : py::cast<int>(seq[2]);
uint64_t kernel = py::cast<uint64_t>(binary.attr("kernel"));
uint64_t shared_mem = py::cast<uint64_t>(binary.attr("shared_mem"));
// actually launch
void *config[] = {CU_LAUNCH_PARAM_BUFFER_POINTER, params.data(),
CU_LAUNCH_PARAM_BUFFER_SIZE, &params_size,
CU_LAUNCH_PARAM_END};
uint64_t _stream = PyLong_AsLong(stream.ptr());
const int numGrids = grid_0 * grid_1 * grid_2;
if (numGrids) {
// release the gil in case the enqueue blocks
// cuda will block if too many ops are enqueued
py::gil_scoped_release allow_threads;
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
_num_warps * 32, 1, 1, shared_mem,
(CUstream)_stream, nullptr, config);
}
return binary;
});
// query maximum shared memory
m.def("max_shared_memory", [](backend_t backend, uint64_t device) {
if (backend == HOST)
return 0;
if (backend == CUDA)
return cuGetInfo<CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN>(
device);
return -1;
});
// query DRAM & L2 cache
m.def("memory_clock_rate", [](backend_t backend, uint64_t device) {
if (backend == CUDA)
return cuGetInfo<CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE>(device);
return -1;
});
m.def("global_memory_bus_width", [](backend_t backend, uint64_t device) {
if (backend == CUDA)
return cuGetInfo<CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH>(device);
return -1;
});
m.def("l2_cache_size", [](backend_t backend, uint64_t device) {
if (backend == CUDA)
return cuGetInfo<CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE>(device);
return -1;
});
// query clock rate (in kilohertz)
m.def("clock_rate", [](backend_t backend, uint64_t device) {
if (backend == CUDA)
return cuGetInfo<CU_DEVICE_ATTRIBUTE_CLOCK_RATE>(device);
return -1;
});
m.def("num_sm", [](backend_t backend, uint64_t device) {
if (backend == CUDA)
return cuGetInfo<CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT>(device);
return -1;
});
// enqueue
m.def("enqueue",
[](backend_t backend, uint64_t stream, uint64_t kernel, uint64_t grid_0,
uint64_t grid_1, uint64_t grid_2, uint64_t block_0, uint64_t block_1,
uint64_t block_2, const std::string &args, int64_t shared_mem) {
void *args_ptr = (void *)args.data();
size_t args_size = args.size();
// release the gil in case the enqueue blocks
// cuda will block if too many ops are enqueued
py::gil_scoped_release allow_threads;
if (backend == HOST)
host_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0,
block_1, block_2, args_ptr, args_size, shared_mem);
if (backend == CUDA)
cu_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1,
block_2, args_ptr, args_size, shared_mem);
});
}
/*****************************************************************************/
/* Python bindings for triton::codegen */
/*****************************************************************************/
typedef std::map<std::string, py::object> asm_map_t;
/*****************************************************************************/
/* Python bindings for triton::ir */
/*****************************************************************************/
@@ -783,6 +299,38 @@ void init_triton_ir(py::module &&m) {
return self.lookupSymbol<mlir::FuncOp>(funcName);
});
m.def(
"parse_mlir_module",
[](const std::string &inputFilename, mlir::MLIRContext &context) {
// open file
std::string errorMessage;
auto input = mlir::openInputFile(inputFilename, &errorMessage);
if (!input)
throw std::runtime_error(errorMessage);
// initialize registry
mlir::DialectRegistry registry;
registry.insert<mlir::triton::TritonDialect,
mlir::triton::gpu::TritonGPUDialect,
mlir::math::MathDialect, mlir::arith::ArithmeticDialect,
mlir::StandardOpsDialect, mlir::scf::SCFDialect>();
context.appendDialectRegistry(registry);
context.loadAllAvailableDialects();
context.allowUnregisteredDialects();
// parse module
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(input), llvm::SMLoc());
mlir::OwningOpRef<mlir::ModuleOp> module(
mlir::parseSourceFile(sourceMgr, &context));
if (!module)
throw std::runtime_error("Parse MLIR file failed.");
return module->clone();
},
ret::take_ownership);
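  // Called from Python as e.g. (illustrative):
  //   mod = _triton.parse_mlir_module("kernel.ttgir", context)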
py::class_<mlir::FuncOp, mlir::OpState>(m, "function")
// .def_property_readonly("attrs", &ir::function::attrs)
// .def("add_attr", &ir::function::add_attr);
@@ -1643,84 +1191,86 @@ void init_triton_ir(py::module &&m) {
}
void init_triton_translation(py::module &m) {
  using ret = py::return_value_policy;

  m.def("get_shared_memory_size", [](mlir::ModuleOp module) {
    auto pass = std::make_unique<mlir::Allocation>(module);
    return pass->getSharedMemorySize();
  });

  // This supersedes the old "translate_triton_gpu_to_ptx" binding, which fused
  // LLVM IR translation, shared-memory analysis, and PTX codegen in one call:
  //   auto [ptxCode, cc, version, ptxasPath] =
  //       triton::translateTritonGPUToPTX(module, device);
  //   mlir::PassManager pm(module->getContext());
  //   auto pass = std::make_unique<mlir::Allocation>(module);
  //   size_t size = pass->getSharedMemorySize();
  //   return std::make_tuple(ptxCode, size);
  m.def(
      "translate_triton_gpu_to_llvmir",
      [](mlir::ModuleOp op) {
        llvm::LLVMContext llvmContext;
        auto llvmModule =
            ::mlir::triton::translateTritonGPUToLLVMIR(&llvmContext, op);
        std::string str;
        llvm::raw_string_ostream os(str);
        llvmModule->print(os, nullptr);
        os.flush();
        return str;
      },
      ret::take_ownership);
m.def(
"translate_llvmir_to_ptx",
[](const std::string llvmIR, int capability, int version) -> std::string {
// create LLVM module from C++
llvm::LLVMContext context;
std::unique_ptr<llvm::MemoryBuffer> buffer =
llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str());
llvm::SMDiagnostic error;
std::unique_ptr<llvm::Module> module =
llvm::parseIR(buffer->getMemBufferRef(), error, context);
// translate module to PTX
auto ptxCode =
triton::translateLLVMIRToPTX(*module, capability, version);
return ptxCode;
},
ret::take_ownership);
m.def("compile_ptx_to_cubin",
[](const std::string &ptxCode, uint64_t device) -> py::object {
[](const std::string &ptxCode, const std::string &ptxasPath,
int capability) -> py::object {
py::gil_scoped_release allow_threads;
-int version;
-int cc;
-std::string ptxasPath;
-triton::getCuCCAndVersionFromDevice(device, &cc, &version,
-&ptxasPath);
-std::string cubin = drv::ptx_to_cubin(ptxCode, ptxasPath, cc);
// compile ptx with ptxas
char _fsrc[L_tmpnam];
char _flog[L_tmpnam];
std::tmpnam(_fsrc);
std::tmpnam(_flog);
std::string fsrc = _fsrc;
std::string flog = _flog;
std::string fbin = fsrc + ".o";
const char *_fbin = fbin.c_str();
std::ofstream ofs(fsrc);
ofs << ptxCode << std::endl;
ofs.close();
std::string cmd;
int err;
cmd = ptxasPath + " -v --gpu-name=sm_" + std::to_string(capability) +
" " + fsrc + " -o " + fsrc + ".o 2> " + flog;
err = system(cmd.c_str());
if (err != 0) {
std::ifstream _log(_flog);
std::string log(std::istreambuf_iterator<char>(_log), {});
unlink(_fsrc);
unlink(_flog);
throw std::runtime_error("Internal Triton PTX codegen error: \n" +
log);
}
std::ifstream _cubin(_fbin, std::ios::binary);
std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
_cubin.close();
unlink(_fsrc);
unlink(_flog);
unlink(_fbin);
py::bytes bytes(cubin);
return bytes;
});
-m.def(
-"load_binary",
-[](const std::string &name, const std::string &data,
-size_t n_shared_bytes, uint64_t device) {
-py::gil_scoped_release allow_threads;
-// create driver handles
-CUfunction fun;
-CUmodule mod;
-drv::dispatch::cuModuleLoadData(&mod, data.c_str());
-drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
-// get allocated registers and spilled registers from the function
-int n_regs = 0;
-int n_spills = 0;
-drv::dispatch::cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS,
-fun);
-drv::dispatch::cuFuncGetAttribute(
-&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
-n_spills /= 4;
-// set dynamic shared memory if necessary
-int shared_optin;
-drv::dispatch::cuDeviceGetAttribute(
-&shared_optin,
-CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device);
-if (n_shared_bytes > 49152 && shared_optin > 49152) {
-drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
-int shared_total, shared_static;
-drv::dispatch::cuDeviceGetAttribute(
-&shared_total,
-CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, device);
-drv::dispatch::cuFuncGetAttribute(
-&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun);
-drv::dispatch::cuFuncSetAttribute(
-fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
-shared_optin - shared_static);
-}
-return std::make_tuple((uint64_t)mod, (uint64_t)fun, (uint64_t)n_regs,
-(uint64_t)n_spills);
-},
-py::return_value_policy::take_ownership);
}
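Together these bindings replace the monolithic `translate_triton_gpu_to_ptx` with stages the Python side can drive and cache independently. A sketch of the full chain, assuming `mod` is a TritonGPU-dialect module in hand and that ptxas lives at the illustrative path below:

    import triton._C.libtriton.triton as _triton

    llvm_ir = _triton.translate_triton_gpu_to_llvmir(mod)   # TritonGPU IR -> LLVM IR
    shmem = _triton.get_shared_memory_size(mod)             # shared memory, in bytes
    ptx = _triton.translate_llvmir_to_ptx(llvm_ir, 80, 74)  # sm_80, PTX ISA 7.4
    cubin = _triton.compile_ptx_to_cubin(ptx, "/usr/local/cuda/bin/ptxas", 80)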
void init_triton(py::module &m) {


@@ -7,6 +7,7 @@ import hashlib
import io
import json
import os
+import re
import shutil
import subprocess
import sys
@@ -843,7 +844,11 @@ def optimize_tritongpu_ir(mod, num_stages):
return mod
-def make_ptx(mod: Any, device: int) -> Tuple[str, int]:
+def make_llvm_ir(mod):
+    return _triton.translate_triton_gpu_to_llvmir(mod)
+def make_ptx(mod: Any, compute_capability: int, ptx_version: int) -> Tuple[str, int]:
'''
Translate TritonGPU module to PTX code.
:param mod: a TritonGPU dialect module
@@ -851,17 +856,17 @@ def make_ptx(mod: Any, device: int) -> Tuple[str, int]:
- PTX code
- shared memory allocation size
'''
-return _triton.translate_triton_gpu_to_ptx(mod, device)
+return _triton.translate_llvmir_to_ptx(mod, compute_capability, ptx_version)
-def make_cubin(ptx, device):
+def make_cubin(ptx: str, ptxas: str, compute_capability: int):
'''
Compile TritonGPU module to cubin.
:param ptx: ptx code
:param device: CUDA device
:return: str
'''
-return _triton.compile_ptx_to_cubin(ptx, device)
+return _triton.compile_ptx_to_cubin(ptx, ptxas, compute_capability)
def ptx_get_kernel_name(ptx: str) -> str:
@@ -877,6 +882,46 @@ def ptx_get_kernel_name(ptx: str) -> str:
return line.split()[-1]
@functools.lru_cache
def ptx_get_version(cuda_version) -> int:
'''
Get the highest PTX version supported by the current CUDA driver.
'''
assert isinstance(cuda_version, str)
major, minor = map(int, cuda_version.split('.'))
version = major * 1000 + minor * 10
if version >= 11040:
return 74
if version >= 11030:
return 73
if version >= 11020:
return 72
if version >= 11010:
return 71
if version >= 11000:
return 70
if version >= 10020:
return 65
if version >= 10010:
return 64
if version >= 10000:
return 63
raise RuntimeError("Triton only support CUDA 10.0 or higher")
def path_to_ptxas():
prefixes = [os.environ.get("TRITON_PTXAS_PATH", ""), "", "/usr/local/cuda/"]
for prefix in prefixes:
ptxas = os.path.join(prefix, "bin", "ptxas")
if os.path.exists(ptxas):
result = subprocess.check_output([ptxas, "--version"], stderr=subprocess.STDOUT)
if result is not None:
version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
if version is not None:
return ptxas, version.group(1)
raise RuntimeError("Cannot find ptxas")
instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"], defaults=[set(), set()])
@@ -895,17 +940,24 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(), specializat
# tritongpu-ir
module = make_tritongpu_ir(module, num_warps)
module = optimize_tritongpu_ir(module, num_stages)
if output == "ttgir":
return module.str()
+# llvm-ir
+llvm_ir = make_llvm_ir(module)
-assert device >= 0, "device should be provided."
-ptx, shem_size = make_ptx(module, device)
+ptxas, cuda_version = path_to_ptxas()
+compute_capability = torch.cuda.get_device_capability(device)
+compute_capability = compute_capability[0] * 10 + compute_capability[1]
+ptx_version = ptx_get_version(cuda_version)
+ptx = make_ptx(llvm_ir, compute_capability, ptx_version)
+shem_size = _triton.get_shared_memory_size(module)
kernel_name = ptx_get_kernel_name(ptx)
if output == "ptx":
return ptx, shem_size, kernel_name
-cubin = make_cubin(ptx, device)
+cubin = make_cubin(ptx, ptxas, compute_capability)
if output == "cubin":
return cubin, ptx, shem_size, kernel_name
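The capability tuple reported by torch is folded into the two-digit sm number that `make_ptx` and `make_cubin` expect; on an A100, for instance:

    >>> import torch
    >>> cc = torch.cuda.get_device_capability(0)  # (8, 0) on an A100
    >>> cc[0] * 10 + cc[1]
    80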
@@ -980,6 +1032,7 @@ def generate_launcher(identifier, constants, signature):
src = f"""
#include \"cuda.h\"
#include <Python.h>
static inline void gpuAssert(CUresult code, const char *file, int line)
{{
if (code != CUDA_SUCCESS)
@@ -993,13 +1046,16 @@ static inline void gpuAssert(CUresult code, const char *file, int line)
PyErr_SetString(PyExc_RuntimeError, err);
}}
}}
#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, CUstream stream, CUfunction function, {arg_decls}) {{
void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
if(gridX*gridY*gridZ > 0){{
CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
}}
}}
static inline CUdeviceptr getPointer(PyObject *obj, int idx) {{
if (PyLong_Check(obj)) {{
return (CUdeviceptr)PyLong_AsUnsignedLongLong(obj);
@@ -1021,6 +1077,7 @@ static inline CUdeviceptr getPointer(PyObject *obj, int idx) {{
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
return (CUdeviceptr)0;
}}
static PyObject* launch(PyObject* self, PyObject* args) {{
int gridX, gridY, gridZ;
uint64_t _stream;
@@ -1039,10 +1096,12 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
Py_INCREF(Py_None);
return Py_None;
}}
static PyMethodDef ModuleMethods[] = {{
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
{{NULL, NULL, 0, NULL}} // sentinel
}};
static struct PyModuleDef ModuleDef = {{
PyModuleDef_HEAD_INIT,
\"launcher\",
@@ -1050,6 +1109,7 @@ static struct PyModuleDef ModuleDef = {{
-1, //size
ModuleMethods
}};
PyMODINIT_FUNC PyInit_launcher(void) {{
PyObject *m = PyModule_Create(&ModuleDef);
if(m == NULL) {{
@@ -1251,7 +1311,10 @@ class CompiledKernel:
self.asm["ptx"] = f.read()
device = torch.cuda.current_device()
-mod, func, n_regs, n_spills = _triton.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
+global cuda_utils
+if cuda_utils is None:
+    cuda_utils = CudaUtils()
+mod, func, n_regs, n_spills = cuda_utils.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
self.cu_module = mod
self.cu_function = func
@@ -1261,3 +1324,118 @@ class CompiledKernel:
stream = torch.cuda.current_stream().cuda_stream
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, *args)
return
class CudaUtils(object):
def __new__(cls):
if not hasattr(cls, 'instance'):
cls.instance = super(CudaUtils, cls).__new__(cls)
return cls.instance
def _generate_src(self):
return """
#include <cuda.h>
#include \"cuda.h\"
#include <Python.h>
static inline void gpuAssert(CUresult code, const char *file, int line)
{
if (code != CUDA_SUCCESS)
{
const char* prefix = "Triton Error [CUDA]: ";
const char* str;
cuGetErrorString(code, &str);
char err[1024] = {0};
strcat(err, prefix);
strcat(err, str);
PyErr_SetString(PyExc_RuntimeError, err);
}
}
#define CUDA_CHECK(ans) { gpuAssert((ans), __FILE__, __LINE__); }
static PyObject* loadBinary(PyObject* self, PyObject* args) {
const char* name;
const char* data;
Py_ssize_t data_size;
int shared;
int device;
if(!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared, &device)) {
return NULL;
}
CUfunction fun;
CUmodule mod;
int32_t n_regs = 0;
int32_t n_spills = 0;
Py_BEGIN_ALLOW_THREADS;
// create driver handles
CUDA_CHECK(cuModuleLoadData(&mod, data));
CUDA_CHECK(cuModuleGetFunction(&fun, mod, name));
// get allocated registers and spilled registers from the function
CUDA_CHECK(cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
CUDA_CHECK(cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
n_spills /= 4;
// set dynamic shared memory if necessary
int shared_optin;
CUDA_CHECK(cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device));
if (shared > 49152 && shared_optin > 49152) {
CUDA_CHECK(cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
int shared_total, shared_static;
CUDA_CHECK(cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, device));
CUDA_CHECK(cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
CUDA_CHECK(cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static));
}
Py_END_ALLOW_THREADS;
if(PyErr_Occurred()) {
return NULL;
}
return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs, n_spills);
}
static PyMethodDef ModuleMethods[] = {
{"load_binary", loadBinary, METH_VARARGS, "Load provided cubin into CUDA driver"},
{NULL, NULL, 0, NULL} // sentinel
};
static struct PyModuleDef ModuleDef = {
PyModuleDef_HEAD_INIT,
\"cuda_utils\",
NULL, //documentation
-1, //size
ModuleMethods
};
PyMODINIT_FUNC PyInit_cuda_utils(void) {
PyObject *m = PyModule_Create(&ModuleDef);
if(m == NULL) {
return NULL;
}
PyModule_AddFunctions(m, ModuleMethods);
return m;
}
"""
def __init__(self):
src = self._generate_src()
key = hashlib.md5(src.encode("utf-8")).hexdigest()
cache = CacheManager(key)
fname = "cuda_utils.so"
if not cache.has_file(fname):
with tempfile.TemporaryDirectory() as tmpdir:
src_path = os.path.join(tmpdir, "main.c")
with open(src_path, "w") as f:
f.write(src)
so = _build("cuda_utils", src_path, tmpdir)
with open(so, "rb") as f:
cache.put(f.read(), fname, binary=True)
import importlib.util
spec = importlib.util.spec_from_file_location("cuda_utils", cache._make_path(fname))
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
self.load_binary = mod.load_binary
cuda_utils = None
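This moves the last CUDA driver dependency out of libtriton.so: the helper extension is compiled once, cached through the same CacheManager used for kernels, and memoized in the module-level `cuda_utils` singleton. Per the `Py_BuildValue("(KKii)", ...)` above, the module and function handles come back as plain integers; a usage sketch with illustrative kernel name and values:

    utils = CudaUtils()  # builds the extension on first use, then loads it from cache
    mod, func, n_regs, n_spills = utils.load_binary("add_kernel", cubin_bytes, 49152, 0)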


@@ -0,0 +1,61 @@
import argparse
import triton
import triton._C.libtriton.triton as libtriton
if __name__ == '__main__':
# valid source and target formats
VALID_FORMATS = ['llvm-ir', 'ptx', 'triton-ir', 'triton-gpu-ir']
# set up the argument parser
# TODO: conditional requirements
parser = argparse.ArgumentParser()
parser.add_argument('src', help="Source file to compile")
parser.add_argument('--target', required=True,
help="Target format, one of: " + ', '.join(VALID_FORMATS))
parser.add_argument('--sm', type=int, help="Compute capability to compile for")
parser.add_argument('--ptx-version', type=int, help="PTX version to compile for")
# parse the args
args = parser.parse_args()
# TODO: clean-up and re-use triton.compiler primitive functions
# check for validity of format arguments
if args.target not in VALID_FORMATS:
print("Invalid target format: " + args.target)
exit(1)
# parse source file to MLIR module
context = libtriton.ir.context()
module = libtriton.ir.parse_mlir_module(args.src, context)
module.context = context
# optimize triton-ir
module = triton.compiler.optimize_triton_ir(module)
if args.target == 'triton-ir':
print(module.str())
exit(0)
# triton-ir -> triton-gpu-ir
module = triton.compiler.make_tritongpu_ir(module, num_warps=4)
module = triton.compiler.optimize_tritongpu_ir(module, num_stages=3)
if args.target == 'triton-gpu-ir':
print(module.str())
exit(0)
# triton-gpu-ir -> llvm-ir
module = triton.compiler.make_llvm_ir(module)
if args.target == 'llvm-ir':
print(module)
exit(0)
if not args.sm:
raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation")
if not args.ptx_version:
raise argparse.ArgumentError(None, "Must specify --ptx-version for PTX compilation")
# llvm-ir -> ptx
module = triton.compiler.make_ptx(module, compute_capability=args.sm, ptx_version=args.ptx_version)
assert args.target == 'ptx'
print(module)
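The lit tests below invoke the tool the same way a user would, e.g. (input path illustrative):

    python3 -m triton.tools.aot kernel.mlir --target=llvm-ir
    python3 -m triton.tools.aot kernel.mlir --target=ptx --sm=80 --ptx-version=63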


@@ -1,4 +1,4 @@
-// RUN: triton-translate %s --target=llvmir | FileCheck %s
+// RUN: python3 -m triton.tools.aot %s --target=llvm-ir | FileCheck %s
// == LLVM IR check begin ==
// CHECK-LABEL: ; ModuleID = 'LLVMDialectModule'


@@ -1,5 +1,4 @@
-// RUN: triton-translate %s --target=ptx --sm=80 --ptx-version=10000 | FileCheck %s
+// RUN: python3 -m triton.tools.aot %s --target=ptx --sm=80 --ptx-version=63 | FileCheck %s
// CHECK-LABEL: // Generated by LLVM NVPTX Back-End
// CHECK: .version 6.3
// CHECK: .target sm_80