Compare commits: legacy-bac...rocm (59 commits)

46fd581b0a, 8cc448d92e, 9a9fabbba9, 15886b5ffc, d5830b4b6a, bba1579485,
cc6b5180c7, dfad6bdf36, f3bcbcfde6, 7ec29a7453, 4fb9d4904e, 4f3e2d6ed7,
fecc7ce248, 277b712284, d024f0cfb8, 1811791665, 9b3f2487b5, 14730a2352,
15683986cd, 48fcd8c987, 8d9572bc63, ffb30cdc52, 7fce2bc5f1, 531ef18cb6,
5f0d90db7e, 03ae41b310, bd61338b31, 6e50f8b2c0, aa556d4f1b, 61e88efb23,
ed9638801a, 8ecab462f6, 648e4cfe89, abe0d3e1b1, 4464dfcc18, 0cae0168ec,
88d57ef9c9, 39381d99f8, df925f7187, e84297ca79, 61c85c18b2, da5c24ffcb,
09302f0106, 9184b5cf65, 8da4323514, eb89e9bdd9, 56a06f7a06, 6a31c43774,
8785793445, d022f5cf2c, 4624fd4e1d, 41144f927f, 4d6d4c9431, 32dbc08c05,
4f21501def, 5c548fb57e, fa4d0fd1ef, 406d03bfaf, 94d5c2e8b5
.gitignore (vendored): 3 changes
@@ -9,4 +9,5 @@ python/triton/_C/libtriton.pyd
 python/triton/_C/libtriton.so
 .vscode
 .vs
+log_*
CMakeLists.txt

@@ -3,6 +3,13 @@ include(ExternalProject)
 
 set(CMAKE_CXX_STANDARD 17)
 
+if(NOT TRITON_LLVM_BUILD_DIR)
+  set(TRITON_LLVM_BUILD_DIR ${CMAKE_BINARY_DIR})
+endif()
+
+set(TRITON_USE_ROCM "$ENV{TRITON_USE_ROCM}")
+set(TRITON_ROCM_DEBUG "$ENV{TRITON_ROCM_DEBUG}")
+
 project(triton)
 include(CTest)
 if(NOT WIN32)
@@ -35,7 +42,11 @@ if(WIN32)
   add_subdirectory(deps/dlfcn-win32/src ${CMAKE_BINARY_DIR}/dlfcn-win32)
 endif()
 
-set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17")
+if (TRITON_USE_ROCM)
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17 -Wno-unused-result -Wno-attributes")
+else()
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17")
+endif()
 
 
 ##########
@@ -135,6 +146,13 @@ if(BUILD_PYTHON_MODULE)
   endif()
   include_directories("." ${PYTHON_SRC_PATH} ${PYTHON_INCLUDE_DIRS} ${CUTLASS_INCLUDE_DIR})
   link_directories(${PYTHON_LINK_DIRS} ${CUTLASS_LIBRARY_DIR})
+  if (TRITON_USE_ROCM)
+    add_definitions(-DUSE_ROCM)
+  endif()
+  if (TRITON_ROCM_DEBUG)
+    add_definitions(-DDEBUG_ROCM)
+  endif()
+
   set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/triton.cc ${PYTHON_SRC_PATH}/superblock.cc ${CUTLASS_SRC})
 endif()
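The net effect of these CMake changes is an environment-driven switch: exporting TRITON_USE_ROCM (and optionally TRITON_ROCM_DEBUG) before configuring turns into the USE_ROCM and DEBUG_ROCM preprocessor definitions for every translation unit. A minimal sketch of how C++ code can branch on them (illustrative only; the messages are not from this patch):

#include <cstdio>

int main() {
#ifdef USE_ROCM
  // set when CMake saw TRITON_USE_ROCM in the environment
  std::puts("ROCm path: compiled with -DUSE_ROCM");
#else
  std::puts("default path: CUDA backend");
#endif
#ifdef DEBUG_ROCM
  // set when CMake saw TRITON_ROCM_DEBUG in the environment
  std::puts("extra ROCm debug output enabled via -DDEBUG_ROCM");
#endif
  return 0;
}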
include/print_helper.h (new executable file): 163 lines
@@ -0,0 +1,163 @@
#pragma once

#ifndef _PRINT_IR_H_
#define _PRINT_IR_H_

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IR/IRPrintingPasses.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include <chrono>   // std::chrono::system_clock (used by return_current_time_and_date)
#include <ctime>    // std::localtime
#include <fstream>
#include <iostream>
#include <memory>   // std::unique_ptr
#include <sstream>
#include <string>
#include <vector>   // print_vector
#include "triton/ir/module.h"
#include "triton/ir/print.h"
#include <iomanip>

#define PRINT_CURRENT_FUNCTION() std::cout << __FILE__ << ":" << __LINE__ << ":" << __FUNCTION__ << std::endl;

// shared counter for tracked IR dumps
static int print_count = 0;

inline std::string return_current_time_and_date()
{
  auto now = std::chrono::system_clock::now();
  auto in_time_t = std::chrono::system_clock::to_time_t(now);

  std::stringstream ss;
  ss << std::put_time(std::localtime(&in_time_t), "%Y-%m-%d--%I-%M-%S");
  return ss.str();
}

template <typename T>
inline void print_vector(std::vector<T> &vec, std::string name = "")
{
  std::cout << name << ": ";
  for (auto v : vec)
  {
    std::cout << v << ", ";
  }

  std::cout << '\b';  // backspace to soften the trailing ", "
  std::cout << std::endl;
}

// render llvm ir as a string (and optionally echo it to stdout)
inline std::string print_llvm_module(llvm::Module *llvm_module, bool print_to_cout = true)
{
  std::cout << "\t" << "print_llvm_module" << std::endl;
  // get module as a string
  std::error_code ec;
  std::string mod_string;
  std::unique_ptr<llvm::raw_string_ostream> ir_ss(
      new llvm::raw_string_ostream(mod_string));
  llvm_module->print(*ir_ss, nullptr);

  // print module
  if (print_to_cout)
  {
    if (!mod_string.empty())
      std::cout << "\t" << mod_string << std::endl;
    else
      std::cout << "\t" << llvm_module->getModuleIdentifier() << ": "
                << "is empty" << std::endl;
  }

  return mod_string;
}

// dump llvm ir to tmp file
inline void write_llvm_ir(llvm::Module *llvm_module, std::string filename = "", bool tracked = false)
{
  // get module string
  std::string module_string = print_llvm_module(llvm_module, false);

  // get file name and path
  if (filename.empty())
    filename = llvm_module->getModuleIdentifier();
  std::string count_str = "";
  if (tracked)
  {
    count_str = "_" + std::to_string(print_count);
  }
  std::string ir_path = std::string("/tmp/") + filename + count_str + std::string(".ll");

  // write file
  std::ofstream output_file(ir_path);
  output_file << module_string;
  output_file.close();

  // increment counter
  if (tracked)
  {
    print_count += 1;
  }
}

inline void print_triton_ir(triton::ir::module ir_ref, std::string name)
{
  std::ofstream ir_out(std::string("/tmp/") + name + std::string("_") + return_current_time_and_date() + std::string(".ttir"));
  ir_out.flush();
  triton::ir::print(ir_ref, ir_out);
  ir_out.close();
}

inline void print_triton_ir(std::string ir_ref, std::string name)
{
  std::ofstream ir_out(std::string("/tmp/") + name + std::string("_") + return_current_time_and_date() + std::string(".ttir"));
  ir_out.flush();
  ir_out << ir_ref << std::endl;
  ir_out.close();
}

inline std::string get_llvm_value_as_str(llvm::Value *llvm_value)
{
  std::string value_str;
  llvm::raw_string_ostream rso(value_str);
  llvm_value->print(rso);
  return rso.str();
}

inline void print_llvm_value(llvm::Value *llvm_value, std::string name = "")
{
  if (llvm_value)
    std::cout << "\t" << name << ": " << get_llvm_value_as_str(llvm_value) << std::endl;
  else
    std::cout << "\t" << name << ": "
              << "is nullptr" << std::endl;
}

inline void print_llvm_type(llvm::Type *llvm_type, std::string name = "")
{
  std::string type_str;
  llvm::raw_string_ostream rso(type_str);
  llvm_type->print(rso);
  std::cout << name << " type: " << rso.str() << std::endl;
}

inline void print_llvm_value_type(llvm::Value *llvm_value, std::string name = "")
{
  print_llvm_type(llvm_value->getType(), name);
}

inline void write_ptx(std::string ptx_str)
{
  std::ofstream file("/tmp/kernel.ptx");
  file << ptx_str;
}

#endif
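To make the helpers concrete, a small usage sketch (it assumes an LLVM development setup on the include path so the header compiles; the vector contents and the commented pass name are invented for illustration):

#include "print_helper.h"
#include <vector>

int main() {
  // e.g. "2022-06-01--11-30-05"; the same stamp print_triton_ir uses in its file names
  std::cout << return_current_time_and_date() << std::endl;

  // prints "tile_sizes: 16, 32, 64,"
  std::vector<int> tile_sizes{16, 32, 64};
  print_vector(tile_sizes, "tile_sizes");

  // Inside a codegen pass one would dump IR snapshots, e.g.:
  //   write_llvm_ir(mod, "after_layout", /*tracked=*/true);
  // which writes /tmp/after_layout_0.ll, /tmp/after_layout_1.ll, ... on successive calls.
  return 0;
}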
include/triton/codegen/target.h

@@ -36,6 +36,7 @@ namespace triton{
 namespace codegen{
 
 class nvidia_cu_target;
+class amd_cl_target;
 
 class target {
 public:
@@ -49,7 +50,12 @@ public:
   virtual Value* get_block_id(Module *module, Builder& builder, unsigned ax) = 0;
   virtual Value* get_num_blocks(Module *module, Builder& builder, unsigned ax) = 0;
   virtual unsigned guaranteed_alignment() = 0;
+#ifdef USE_ROCM
+  amd_cl_target* as_nvidia();
+  amd_cl_target* as_amd();
+#else
   nvidia_cu_target* as_nvidia();
+#endif
   bool is_gpu() const;
 
 private:
@@ -67,6 +73,7 @@ public:
   Value* get_block_id(Module *module, Builder& builder, unsigned ax);
   Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
   unsigned guaranteed_alignment() { return 16; }
+  int sm() { return 0; } // treat as if old CUDA device
 };
 
 class nvidia_cu_target: public target {
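One non-obvious choice here: under USE_ROCM, as_nvidia() returns amd_cl_target*. Reading between the lines of the diff (the header itself does not explain it), the point is that existing call sites which fetch the target through as_nvidia() keep compiling unchanged on the ROCm path, while the new sm() stub makes compute-capability checks treat the AMD device like an old CUDA device. A hedged sketch:

// Illustrative fragment; tgt is a triton::codegen::target* the caller already holds.
// The same line compiles on both paths: under USE_ROCM, as_nvidia() yields the
// AMD target, whose stub sm() returns 0.
int cc = tgt->as_nvidia()->sm();
bool use_tensor_cores = cc >= 70;  // always false on the ROCm path, so sm70+ features stay off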
include/triton/driver/dispatch.h

@@ -11,7 +11,6 @@
 #include "triton/external/CUDA/nvml.h"
 
 //// HIP backend
-//#define __HIP_PLATFORM_AMD__
 #include "triton/external/hip.h"
 
 //Exceptions
@@ -183,7 +182,8 @@ public:
   static hipError_t hipEventElapsedTime(float *pMilliseconds, hipEvent_t hStart, hipEvent_t hEnd);
   static hipError_t hipEventRecord(hipEvent_t hEvent, hipStream_t hStream);
   static hipError_t hipEventDestroy(hipEvent_t hEvent);
 
+  // error handling
+  static hipError_t hipGetLastError(void);
 
 private:
@@ -309,6 +309,8 @@ private:
   static void* hipEventElapsedTime_;
   static void* hipEventRecord_;
   static void* hipEventDestroy_;
+  // error handling
+  static void* hipGetLastError_;
 };
 
 }
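A sketch of what the new entry enables (assuming the dispatch class lazy-binds hipGetLastError from the HIP runtime the same way it binds the event functions above):

// Illustrative only: clear and inspect a pending HIP error after a launch.
hipError_t err = triton::driver::dispatch::hipGetLastError();
if (err != hipSuccess)
  std::cerr << "HIP error code " << static_cast<int>(err) << std::endl;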
include/triton/driver/llvm.h

@@ -13,8 +13,12 @@ std::string path_to_ptxas(int& version);
 std::string llir_to_ptx(llvm::Module* module, int cc, int version);
 std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas_path, int cc);
 CUmodule ptx_to_cumodule(const std::string& ptx, int cc);
 std::string llir_to_amdgpu(llvm::Module* module, const std::string& proc);
+std::tuple<std::string, std::string> llir_to_amdgcn(llvm::Module* module, const std::string& proc);
+hipModule_t amdgpu_to_hipmodule(const std::string& path);
 
 }
 }
 
+#define STRINGIFY_HELPER(X) #X
+#define STRINGIFY(X) STRINGIFY_HELPER(X)
+
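Read together, the two new declarations outline the AMD compilation pipeline: LLVM IR is lowered to GCN assembly plus an on-disk code object, and the code object is then loaded as a HIP module. A hedged sketch of that flow (the gfx90a target string, and the assumption that the tuple packs the assembly text and the code-object path, are illustrative rather than stated by the header):

// llvm_mod is an llvm::Module* produced by earlier codegen stages.
auto [amdgcn_asm, hsaco_path] =
    triton::driver::llir_to_amdgcn(llvm_mod, /*proc=*/"gfx90a");
hipModule_t hip_mod = triton::driver::amdgpu_to_hipmodule(hsaco_path);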
include/triton/external/CUDA/cuda.h (vendored): 46 changes
@@ -818,7 +818,7 @@ typedef enum CUcomputemode_enum {
  * Memory advise values
  */
 typedef enum CUmem_advise_enum {
-    CU_MEM_ADVISE_SET_READ_MOSTLY = 1, /**< Data will mostly be read and only occasionally be written to */
+    CU_MEM_ADVISE_SET_READ_MOSTLY = 1, /**< Data will mostly be read and only occassionally be written to */
     CU_MEM_ADVISE_UNSET_READ_MOSTLY = 2, /**< Undo the effect of ::CU_MEM_ADVISE_SET_READ_MOSTLY */
     CU_MEM_ADVISE_SET_PREFERRED_LOCATION = 3, /**< Set the preferred location for the data as the specified device */
     CU_MEM_ADVISE_UNSET_PREFERRED_LOCATION = 4, /**< Clear the preferred location for the data */
@@ -827,7 +827,7 @@ typedef enum CUmem_advise_enum {
 } CUmem_advise;
 
 typedef enum CUmem_range_attribute_enum {
-    CU_MEM_RANGE_ATTRIBUTE_READ_MOSTLY = 1, /**< Whether the range will mostly be read and only occasionally be written to */
+    CU_MEM_RANGE_ATTRIBUTE_READ_MOSTLY = 1, /**< Whether the range will mostly be read and only occassionally be written to */
     CU_MEM_RANGE_ATTRIBUTE_PREFERRED_LOCATION = 2, /**< The preferred location of the range */
     CU_MEM_RANGE_ATTRIBUTE_ACCESSED_BY = 3, /**< Memory range has ::CU_MEM_ADVISE_SET_ACCESSED_BY set for specified device */
     CU_MEM_RANGE_ATTRIBUTE_LAST_PREFETCH_LOCATION = 4 /**< The last location to which the range was prefetched */
@@ -849,7 +849,7 @@ typedef enum CUjit_option_enum
  * IN: Specifies minimum number of threads per block to target compilation
  * for\n
  * OUT: Returns the number of threads the compiler actually targeted.
- * This restricts the resource utilization of the compiler (e.g. max
+ * This restricts the resource utilization fo the compiler (e.g. max
  * registers) such that a block with the given number of threads should be
  * able to launch based on register limitations. Note, this option does not
  * currently take into account any other resource limitations, such as
@@ -974,10 +974,10 @@ typedef enum CUjit_option_enum
     CU_JIT_FAST_COMPILE,
 
     /**
-     * Array of device symbol names that will be relocated to the corresponding
+     * Array of device symbol names that will be relocated to the corresponing
      * host addresses stored in ::CU_JIT_GLOBAL_SYMBOL_ADDRESSES.\n
      * Must contain ::CU_JIT_GLOBAL_SYMBOL_COUNT entries.\n
-     * When loading a device module, driver will relocate all encountered
+     * When loding a device module, driver will relocate all encountered
      * unresolved symbols to the host addresses.\n
      * It is only allowed to register symbols that correspond to unresolved
      * global variables.\n
@@ -1194,7 +1194,7 @@ typedef enum CUlimit_enum {
  * Resource types
  */
 typedef enum CUresourcetype_enum {
-    CU_RESOURCE_TYPE_ARRAY = 0x00, /**< Array resource */
+    CU_RESOURCE_TYPE_ARRAY = 0x00, /**< Array resoure */
     CU_RESOURCE_TYPE_MIPMAPPED_ARRAY = 0x01, /**< Mipmapped array resource */
     CU_RESOURCE_TYPE_LINEAR = 0x02, /**< Linear resource */
     CU_RESOURCE_TYPE_PITCH2D = 0x03 /**< Pitch 2D resource */
@@ -2914,9 +2914,9 @@ typedef struct CUmemAllocationProp_st {
     CUmemLocation location;
     /**
      * Windows-specific POBJECT_ATTRIBUTES required when
-     * ::CU_MEM_HANDLE_TYPE_WIN32 is specified. This object attributes structure
+     * ::CU_MEM_HANDLE_TYPE_WIN32 is specified. This object atributes structure
      * includes security attributes that define
-     * the scope of which exported allocations may be transferred to other
+     * the scope of which exported allocations may be tranferred to other
      * processes. In all other cases, this field is required to be zero.
      */
     void *win32HandleMetaData;
@@ -3036,7 +3036,7 @@ typedef struct CUmemPoolProps_st {
     /**
      * Windows-specific LPSECURITYATTRIBUTES required when
      * ::CU_MEM_HANDLE_TYPE_WIN32 is specified. This security attribute defines
-     * the scope of which exported allocations may be transferred to other
+     * the scope of which exported allocations may be tranferred to other
      * processes. In all other cases, this field is required to be zero.
      */
     void *win32SecurityAttributes;
@@ -3519,7 +3519,7 @@ CUresult CUDAAPI cuDeviceGet(CUdevice *device, int ordinal);
 CUresult CUDAAPI cuDeviceGetCount(int *count);
 
 /**
- * \brief Returns an identifier string for the device
+ * \brief Returns an identifer string for the device
  *
  * Returns an ASCII string identifying the device \p dev in the NULL-terminated
 * string pointed to by \p name. \p len specifies the maximum length of the
@@ -3556,7 +3556,7 @@ CUresult CUDAAPI cuDeviceGetName(char *name, int len, CUdevice dev);
 * Note there is a later version of this API, ::cuDeviceGetUuid_v2. It will
 * supplant this version in 12.0, which is retained for minor version compatibility.
 *
- * Returns 16-octets identifying the device \p dev in the structure
+ * Returns 16-octets identifing the device \p dev in the structure
 * pointed by the \p uuid.
 *
 * \param uuid - Returned UUID
@@ -3586,7 +3586,7 @@ CUresult CUDAAPI cuDeviceGetUuid(CUuuid *uuid, CUdevice dev);
 /**
 * \brief Return an UUID for the device (11.4+)
 *
- * Returns 16-octets identifying the device \p dev in the structure
+ * Returns 16-octets identifing the device \p dev in the structure
 * pointed by the \p uuid. If the device is in MIG mode, returns its
 * MIG UUID which uniquely identifies the subscribed MIG compute instance.
 *
@@ -3867,7 +3867,7 @@ CUresult CUDAAPI cuDeviceGetTexture1DLinearMaxWidth(size_t *maxWidthInElements,
 * supports native atomic operations.
 * - ::CU_DEVICE_ATTRIBUTE_SINGLE_TO_DOUBLE_PRECISION_PERF_RATIO: Ratio of single precision performance
 * (in floating-point operations per second) to double precision performance.
- * - ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS: Device supports coherently accessing
+ * - ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS: Device suppports coherently accessing
 * pageable memory without calling cudaHostRegister on it.
 * - ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS: Device can coherently access managed memory
 * concurrently with the CPU.
@@ -3875,7 +3875,7 @@ CUresult CUDAAPI cuDeviceGetTexture1DLinearMaxWidth(size_t *maxWidthInElements,
 * - ::CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM: Device can access host registered
 * memory at the same virtual address as the CPU.
 * - ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN: The maximum per block shared memory size
- * supported on this device. This is the maximum value that can be opted into when using the cuFuncSetAttribute() call.
+ * suported on this device. This is the maximum value that can be opted into when using the cuFuncSetAttribute() call.
 * For more details see ::CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES
 * - ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES: Device accesses pageable memory via the host's
 * page tables.
@@ -4132,7 +4132,7 @@ __CUDA_DEPRECATED CUresult CUDAAPI cuDeviceGetProperties(CUdevprop *prop, CUdevi
 *
 * \deprecated
 *
- * This function was deprecated as of CUDA 5.0 and its functionality superseded
+ * This function was deprecated as of CUDA 5.0 and its functionality superceded
 * by ::cuDeviceGetAttribute().
 *
 * Returns in \p *major and \p *minor the major and minor revision numbers that
@@ -4962,10 +4962,10 @@ CUresult CUDAAPI cuCtxSynchronize(void);
 * returned.
 *
 * - ::CU_LIMIT_MAX_L2_FETCH_GRANULARITY controls the L2 cache fetch granularity.
- *   Values can range from 0B to 128B. This is purely a performance hint and
+ *   Values can range from 0B to 128B. This is purely a performence hint and
 *   it can be ignored or clamped depending on the platform.
 *
- * - ::CU_LIMIT_PERSISTING_L2_CACHE_SIZE controls size in bytes available for
+ * - ::CU_LIMIT_PERSISTING_L2_CACHE_SIZE controls size in bytes availabe for
 *   persisting L2 cache. This is purely a performance hint and it can be
 *   ignored or clamped depending on the platform.
 *
@@ -6398,7 +6398,7 @@ CUresult CUDAAPI cuMemHostGetFlags(unsigned int *pFlags, void *p);
 * ::cuStreamAttachMemAsync will be required to enable access on such devices.
 *
 * If the association is later changed via ::cuStreamAttachMemAsync to
- * a single stream, the default association as specified during ::cuMemAllocManaged
+ * a single stream, the default association as specifed during ::cuMemAllocManaged
 * is restored when that stream is destroyed. For __managed__ variables, the
 * default association is always ::CU_MEM_ATTACH_GLOBAL. Note that destroying a
 * stream is an asynchronous operation, and as a result, the change to default
@@ -9616,13 +9616,13 @@ CUresult CUDAAPI cuMemAddressFree(CUdeviceptr ptr, size_t size);
 * \brief Create a CUDA memory handle representing a memory allocation of a given size described by the given properties
 *
 * This creates a memory allocation on the target device specified through the
- * \p prop structure. The created allocation will not have any device or host
+ * \p prop strcuture. The created allocation will not have any device or host
 * mappings. The generic memory \p handle for the allocation can be
 * mapped to the address space of calling process via ::cuMemMap. This handle
 * cannot be transmitted directly to other processes (see
 * ::cuMemExportToShareableHandle). On Windows, the caller must also pass
 * an LPSECURITYATTRIBUTE in \p prop to be associated with this handle which
- * limits or allows access to this handle for a recipient process (see
+ * limits or allows access to this handle for a recepient process (see
 * ::CUmemAllocationProp::win32HandleMetaData for more). The \p size of this
 * allocation must be a multiple of the the value given via
 * ::cuMemGetAllocationGranularity with the ::CU_MEM_ALLOC_GRANULARITY_MINIMUM
@@ -9660,7 +9660,7 @@ CUresult CUDAAPI cuMemCreate(CUmemGenericAllocationHandle *handle, size_t size,
 * are unmapped and when all outstanding references to the handle (including it's
 * shareable counterparts) are also released. The generic memory handle can be
 * freed when there are still outstanding mappings made with this handle. Each
- * time a recipient process imports a shareable handle, it needs to pair it with
+ * time a recepient process imports a shareable handle, it needs to pair it with
 * ::cuMemRelease for the handle to be freed. If \p handle is not a valid handle
 * the behavior is undefined.
 *
@@ -10975,7 +10975,7 @@ CUresult CUDAAPI cuMemAdvise(CUdeviceptr devPtr, size_t count, CUmem_advise advi
 * a GPU id or CU_DEVICE_CPU depending on whether the last location for prefetch was a GPU or the CPU
 * respectively. If any page in the memory range was never explicitly prefetched or if all pages were not
 * prefetched to the same location, CU_DEVICE_INVALID will be returned. Note that this simply returns the
- * last location that the application requested to prefetch the memory range to. It gives no indication as to
+ * last location that the applicaton requested to prefetch the memory range to. It gives no indication as to
 * whether the prefetch operation to that location has completed or even begun.
 *
 * \param data - A pointers to a memory location where the result
@@ -13561,7 +13561,7 @@ CUresult CUDAAPI cuLaunchCooperativeKernel(CUfunction f,
 * All kernels launched must be identical with respect to the compiled code. Note that
 * any __device__, __constant__ or __managed__ variables present in the module that owns
 * the kernel launched on each device, are independently instantiated on every device.
- * It is the application's responsibility to ensure these variables are initialized and
+ * It is the application's responsiblity to ensure these variables are initialized and
 * used appropriately.
 *
 * The size of the grids as specified in blocks, the size of the blocks themselves
include/triton/external/CUDA/nvml.h (vendored): 46 changes
@@ -328,7 +328,7 @@ typedef enum nvmlGpuLevel_enum
 typedef enum nvmlGpuP2PStatus_enum
 {
     NVML_P2P_STATUS_OK = 0,
-    NVML_P2P_STATUS_CHIPSET_NOT_SUPPORTED,
+    NVML_P2P_STATUS_CHIPSET_NOT_SUPPORED,
     NVML_P2P_STATUS_GPU_NOT_SUPPORTED,
     NVML_P2P_STATUS_IOH_TOPOLOGY_NOT_SUPPORTED,
     NVML_P2P_STATUS_DISABLED_BY_REGKEY,
@@ -736,7 +736,7 @@ typedef enum nvmlReturn_enum
     NVML_ERROR_IN_USE = 19, //!< An operation cannot be performed because the GPU is currently in use
     NVML_ERROR_MEMORY = 20, //!< Insufficient memory
     NVML_ERROR_NO_DATA = 21, //!<No data
-    NVML_ERROR_VGPU_ECC_NOT_SUPPORTED = 22, //!< The requested vgpu operation is not available on target device, because ECC is enabled
+    NVML_ERROR_VGPU_ECC_NOT_SUPPORTED = 22, //!< The requested vgpu operation is not available on target device, becasue ECC is enabled
     NVML_ERROR_UNKNOWN = 999 //!< An internal driver error occurred
 } nvmlReturn_t;
 
@@ -1463,7 +1463,7 @@ typedef struct nvmlEncoderSessionInfo_st
 */
 typedef enum nvmlFBCSessionType_enum
 {
-    NVML_FBC_SESSION_TYPE_UNKNOWN = 0, //!< Unknown
+    NVML_FBC_SESSION_TYPE_UNKNOWN = 0, //!< Unknwon
     NVML_FBC_SESSION_TYPE_TOSYS, //!< ToSys
     NVML_FBC_SESSION_TYPE_CUDA, //!< Cuda
     NVML_FBC_SESSION_TYPE_VID, //!< Vid
@@ -3678,10 +3678,10 @@ nvmlReturn_t DECLDIR nvmlDeviceGetEncoderStats (nvmlDevice_t device, unsigned in
 * Retrieves information about active encoder sessions on a target device.
 *
 * An array of active encoder sessions is returned in the caller-supplied buffer pointed at by \a sessionInfos. The
- * array element count is passed in \a sessionCount, and \a sessionCount is used to return the number of sessions
+ * array elememt count is passed in \a sessionCount, and \a sessionCount is used to return the number of sessions
 * written to the buffer.
 *
- * If the supplied buffer is not large enough to accommodate the active session array, the function returns
+ * If the supplied buffer is not large enough to accomodate the active session array, the function returns
 * NVML_ERROR_INSUFFICIENT_SIZE, with the element count of nvmlEncoderSessionInfo_t array required in \a sessionCount.
 * To query the number of active encoder sessions, call this function with *sessionCount = 0. The code will return
 * NVML_SUCCESS with number of active encoder sessions updated in *sessionCount.
@@ -3727,7 +3727,7 @@ nvmlReturn_t DECLDIR nvmlDeviceGetDecoderUtilization(nvmlDevice_t device, unsign
 * For Maxwell &tm; or newer fully supported devices.
 *
 * @param device The identifier of the target device
- * @param fbcStats Reference to nvmlFBCStats_t structure containing NvFBC stats
+ * @param fbcStats Reference to nvmlFBCStats_t structure contianing NvFBC stats
 *
 * @return
 *         - \ref NVML_SUCCESS if \a fbcStats is fetched
@@ -3742,10 +3742,10 @@ nvmlReturn_t DECLDIR nvmlDeviceGetFBCStats(nvmlDevice_t device, nvmlFBCStats_t *
 * Retrieves information about active frame buffer capture sessions on a target device.
 *
 * An array of active encoder sessions is returned in the caller-supplied buffer pointed at by \a sessionInfo. The
- * array element count is passed in \a sessionCount, and \a sessionCount is used to return the number of sessions
+ * array elememt count is passed in \a sessionCount, and \a sessionCount is used to return the number of sessions
 * written to the buffer.
 *
- * If the supplied buffer is not large enough to accommodate the active session array, the function returns
+ * If the supplied buffer is not large enough to accomodate the active session array, the function returns
 * NVML_ERROR_INSUFFICIENT_SIZE, with the element count of nvmlFBCSessionInfo_t array required in \a sessionCount.
 * To query the number of active FBC sessions, call this function with *sessionCount = 0. The code will return
 * NVML_SUCCESS with number of active FBC sessions updated in *sessionCount.
@@ -4208,7 +4208,7 @@ nvmlReturn_t DECLDIR nvmlDeviceGetRetiredPages(nvmlDevice_t device, nvmlPageReti
 * The address information provided from this API is the hardware address of the page that was retired. Note
 * that this does not match the virtual address used in CUDA, but will match the address information in XID 63
 *
- * \note nvmlDeviceGetRetiredPages_v2 adds an additional timestamps parameter to return the time of each page's
+ * \note nvmlDeviceGetRetiredPages_v2 adds an additional timestamps paramter to return the time of each page's
 *       retirement.
 *
 * For Kepler &tm; or newer fully supported devices.
@@ -4476,7 +4476,7 @@ nvmlReturn_t DECLDIR nvmlDeviceSetDriverModel(nvmlDevice_t device, nvmlDriverMod
 * Set clocks that device will lock to.
 *
 * Sets the clocks that the device will be running at to the value in the range of minGpuClockMHz to maxGpuClockMHz.
- * Setting this will supersede application clock values and take effect regardless if a cuda app is running.
+ * Setting this will supercede application clock values and take effect regardless if a cuda app is running.
 * See /ref nvmlDeviceSetApplicationsClocks
 *
 * Can be used as a setting to request constant performance.
@@ -5297,7 +5297,7 @@ nvmlReturn_t DECLDIR nvmlDeviceSetVirtualizationMode(nvmlDevice_t device, nvmlGp
 * pointed at by \a vgpuTypeIds. The element count of nvmlVgpuTypeId_t array is passed in \a vgpuCount, and \a vgpuCount
 * is used to return the number of vGPU types written to the buffer.
 *
- * If the supplied buffer is not large enough to accommodate the vGPU type array, the function returns
+ * If the supplied buffer is not large enough to accomodate the vGPU type array, the function returns
 * NVML_ERROR_INSUFFICIENT_SIZE, with the element count of nvmlVgpuTypeId_t array required in \a vgpuCount.
 * To query the number of vGPU types supported for the GPU, call this function with *vgpuCount = 0.
 * The code will return NVML_ERROR_INSUFFICIENT_SIZE, or NVML_SUCCESS if no vGPU types are supported.
@@ -5327,9 +5327,9 @@ nvmlReturn_t DECLDIR nvmlDeviceGetSupportedVgpus(nvmlDevice_t device, unsigned i
 * can concurrently run on a device. For example, if only one vGPU type is allowed at a time on a device, then the creatable
 * list will be restricted to whatever vGPU type is already running on the device.
 *
- * If the supplied buffer is not large enough to accommodate the vGPU type array, the function returns
+ * If the supplied buffer is not large enough to accomodate the vGPU type array, the function returns
 * NVML_ERROR_INSUFFICIENT_SIZE, with the element count of nvmlVgpuTypeId_t array required in \a vgpuCount.
- * To query the number of vGPU types creatable for the GPU, call this function with *vgpuCount = 0.
+ * To query the number of vGPU types createable for the GPU, call this function with *vgpuCount = 0.
 * The code will return NVML_ERROR_INSUFFICIENT_SIZE, or NVML_SUCCESS if no vGPU types are creatable.
 *
 * @param device                   The identifier of the target device
@@ -5392,7 +5392,7 @@ nvmlReturn_t DECLDIR nvmlVgpuTypeGetName(nvmlVgpuTypeId_t vgpuTypeId, char *vgpu
 *
 * @param vgpuTypeId               Handle to vGPU type
 * @param deviceID                 Device ID and vendor ID of the device contained in single 32 bit value
- * @param subsystemID              subsystem ID and subsystem vendor ID of the device contained in single 32 bit value
+ * @param subsystemID              Subsytem ID and subsytem vendor ID of the device contained in single 32 bit value
 *
 * @return
 *         - \ref NVML_SUCCESS                 successful completion
@@ -5516,10 +5516,10 @@ nvmlReturn_t DECLDIR nvmlVgpuTypeGetMaxInstances(nvmlDevice_t device, nvmlVgpuTy
 * Retrieve the active vGPU instances on a device.
 *
 * An array of active vGPU instances is returned in the caller-supplied buffer pointed at by \a vgpuInstances. The
- * array element count is passed in \a vgpuCount, and \a vgpuCount is used to return the number of vGPU instances
+ * array elememt count is passed in \a vgpuCount, and \a vgpuCount is used to return the number of vGPU instances
 * written to the buffer.
 *
- * If the supplied buffer is not large enough to accommodate the vGPU instance array, the function returns
+ * If the supplied buffer is not large enough to accomodate the vGPU instance array, the function returns
 * NVML_ERROR_INSUFFICIENT_SIZE, with the element count of nvmlVgpuInstance_t array required in \a vgpuCount.
 * To query the number of active vGPU instances, call this function with *vgpuCount = 0. The code will return
 * NVML_ERROR_INSUFFICIENT_SIZE, or NVML_SUCCESS if no vGPU Types are supported.
@@ -5702,7 +5702,7 @@ nvmlReturn_t DECLDIR nvmlVgpuInstanceGetFrameRateLimit(nvmlVgpuInstance_t vgpuIn
 * @param encoderCapacity          Reference to an unsigned int for the encoder capacity
 *
 * @return
- *         - \ref NVML_SUCCESS if \a encoderCapacity has been retrieved
+ *         - \ref NVML_SUCCESS if \a encoderCapacity has been retrived
 *         - \ref NVML_ERROR_UNINITIALIZED if the library has not been successfully initialized
 *         - \ref NVML_ERROR_INVALID_ARGUMENT if \a vgpuInstance is 0, or \a encoderQueryType is invalid
 *         - \ref NVML_ERROR_NOT_FOUND if \a vgpuInstance does not match a valid active vGPU instance on the system
@@ -5863,10 +5863,10 @@ nvmlReturn_t DECLDIR nvmlVgpuInstanceGetEncoderStats(nvmlVgpuInstance_t vgpuInst
 * Retrieves information about all active encoder sessions on a vGPU Instance.
 *
 * An array of active encoder sessions is returned in the caller-supplied buffer pointed at by \a sessionInfo. The
- * array element count is passed in \a sessionCount, and \a sessionCount is used to return the number of sessions
+ * array elememt count is passed in \a sessionCount, and \a sessionCount is used to return the number of sessions
 * written to the buffer.
 *
- * If the supplied buffer is not large enough to accommodate the active session array, the function returns
+ * If the supplied buffer is not large enough to accomodate the active session array, the function returns
 * NVML_ERROR_INSUFFICIENT_SIZE, with the element count of nvmlEncoderSessionInfo_t array required in \a sessionCount.
 * To query the number of active encoder sessions, call this function with *sessionCount = 0. The code will return
 * NVML_SUCCESS with number of active encoder sessions updated in *sessionCount.
@@ -5896,7 +5896,7 @@ nvmlReturn_t DECLDIR nvmlVgpuInstanceGetEncoderSessions(nvmlVgpuInstance_t vgpuI
 * For Maxwell &tm; or newer fully supported devices.
 *
 * @param vgpuInstance             Identifier of the target vGPU instance
- * @param fbcStats                 Reference to nvmlFBCStats_t structure containing NvFBC stats
+ * @param fbcStats                 Reference to nvmlFBCStats_t structure contianing NvFBC stats
 *
 * @return
 *         - \ref NVML_SUCCESS if \a fbcStats is fetched
@@ -5914,7 +5914,7 @@ nvmlReturn_t DECLDIR nvmlVgpuInstanceGetFBCStats(nvmlVgpuInstance_t vgpuInstance
 * array element count is passed in \a sessionCount, and \a sessionCount is used to return the number of sessions
 * written to the buffer.
 *
- * If the supplied buffer is not large enough to accommodate the active session array, the function returns
+ * If the supplied buffer is not large enough to accomodate the active session array, the function returns
 * NVML_ERROR_INSUFFICIENT_SIZE, with the element count of nvmlFBCSessionInfo_t array required in \a sessionCount.
 * To query the number of active FBC sessions, call this function with *sessionCount = 0. The code will return
 * NVML_SUCCESS with number of active FBC sessions updated in *sessionCount.
@@ -6094,7 +6094,7 @@ typedef struct nvmlVgpuPgpuMetadata_st
     unsigned int version; //!< Current version of the structure
     unsigned int revision; //!< Current revision of the structure
     char hostDriverVersion[NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE]; //!< Host driver version
-    unsigned int pgpuVirtualizationCaps; //!< Pgpu virtualization capabilities bitfield
+    unsigned int pgpuVirtualizationCaps; //!< Pgpu virtualizaion capabilities bitfileld
     unsigned int reserved[7]; //!< Reserved for internal use
     unsigned int opaqueDataSize; //!< Size of opaque data field in bytes
     char opaqueData[4]; //!< Opaque data
@@ -6191,7 +6191,7 @@ nvmlReturn_t DECLDIR nvmlDeviceGetVgpuMetadata(nvmlDevice_t device, nvmlVgpuPgpu
 *
 * The caller passes in a buffer via \a compatibilityInfo, into which a compatibility information structure is written. The
 * structure defines the states in which the vGPU / VM may be booted on the physical GPU. If the vGPU / VM compatibility
- * with the physical GPU is limited, a limit code indicates the factor limiting compatibility.
+ * with the physical GPU is limited, a limit code indicates the factor limiting compability.
 * (see \ref nvmlVgpuPgpuCompatibilityLimitCode_t for details).
 *
 * Note: vGPU compatibility does not take into account dynamic capacity conditions that may limit a system's ability to
include/triton/external/half.hpp (vendored): 18 changes
@@ -950,7 +950,7 @@ namespace half_float
 /// Convert half-precision floating point to integer.
 /// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding
 /// \tparam E `true` for round to even, `false` for round away from zero
- /// \tparam T type to convert to (builtin integer type with at least 16 bits precision, excluding any implicit sign bits)
+ /// \tparam T type to convert to (buitlin integer type with at least 16 bits precision, excluding any implicit sign bits)
 /// \param value binary representation of half-precision value
 /// \return integral value
 template<std::float_round_style R,bool E,typename T> T half2int_impl(uint16 value)
@@ -988,13 +988,13 @@ namespace half_float
 
 /// Convert half-precision floating point to integer.
 /// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding
- /// \tparam T type to convert to (builtin integer type with at least 16 bits precision, excluding any implicit sign bits)
+ /// \tparam T type to convert to (buitlin integer type with at least 16 bits precision, excluding any implicit sign bits)
 /// \param value binary representation of half-precision value
 /// \return integral value
 template<std::float_round_style R,typename T> T half2int(uint16 value) { return half2int_impl<R,HALF_ROUND_TIES_TO_EVEN,T>(value); }
 
 /// Convert half-precision floating point to integer using round-to-nearest-away-from-zero.
- /// \tparam T type to convert to (builtin integer type with at least 16 bits precision, excluding any implicit sign bits)
+ /// \tparam T type to convert to (buitlin integer type with at least 16 bits precision, excluding any implicit sign bits)
 /// \param value binary representation of half-precision value
 /// \return integral value
 template<typename T> T half2int_up(uint16 value) { return half2int_impl<std::round_to_nearest,0,T>(value); }
@@ -1053,7 +1053,7 @@ namespace half_float
 
 /// Half-precision floating point type.
 /// This class implements an IEEE-conformant half-precision floating point type with the usual arithmetic operators and
- /// conversions. It is implicitly convertible to single-precision floating point, which makes arithmetic expressions and
+ /// conversions. It is implicitly convertible to single-precision floating point, which makes artihmetic expressions and
 /// functions with mixed-type operands to be of the most precise operand type. Additionally all arithmetic operations
 /// (and many mathematical functions) are carried out in single-precision internally. All conversions from single- to
 /// half-precision are done using the library's default rounding mode, but temporary results inside chained arithmetic
@@ -1062,7 +1062,7 @@ namespace half_float
 /// According to the C++98/03 definition, the half type is not a POD type. But according to C++11's less strict and
 /// extended definitions it is both a standard layout type and a trivially copyable type (even if not a POD type), which
 /// means it can be standard-conformantly copied using raw binary copies. But in this context some more words about the
- /// actual size of the type. Although the half is representing an IEEE 16-bit type, it does not necessarily have to be of
+ /// actual size of the type. Although the half is representing an IEEE 16-bit type, it does not neccessarily have to be of
 /// exactly 16-bits size. But on any reasonable implementation the actual binary representation of this type will most
 /// probably not ivolve any additional "magic" or padding beyond the simple binary representation of the underlying 16-bit
 /// IEEE number, even if not strictly guaranteed by the standard. But even then it only has an actual size of 16 bits if
@@ -2181,7 +2181,7 @@ namespace half_float
 
 /// Identity.
 /// \param arg operand
- /// \return unchanged operand
+ /// \return uncahnged operand
 template<typename T> HALF_CONSTEXPR typename enable<T,T>::type operator+(T arg) { return arg; }
 
 /// Negation.
@@ -2620,7 +2620,7 @@ namespace half_float
 /// Multiply by power of two.
 /// \param arg number to modify
 /// \param exp power of two to multiply with
- /// \return \a arg multiplied by 2 raised to \a exp
+ /// \return \a arg multplied by 2 raised to \a exp
 // template<typename T> typename enable<half,T>::type ldexp(T arg, int exp) { return functions::scalbln(arg, exp); }
 inline half ldexp(half arg, int exp) { return functions::scalbln(arg, exp); }
 inline half ldexp(expr arg, int exp) { return functions::scalbln(arg, exp); }
@@ -2636,7 +2636,7 @@ namespace half_float
 /// Multiply by power of two.
 /// \param arg number to modify
 /// \param exp power of two to multiply with
- /// \return \a arg multiplied by 2 raised to \a exp
+ /// \return \a arg multplied by 2 raised to \a exp
 // template<typename T> typename enable<half,T>::type scalbn(T arg, int exp) { return functions::scalbln(arg, exp); }
 inline half scalbn(half arg, int exp) { return functions::scalbln(arg, exp); }
 inline half scalbn(expr arg, int exp) { return functions::scalbln(arg, exp); }
@@ -2644,7 +2644,7 @@ namespace half_float
 /// Multiply by power of two.
 /// \param arg number to modify
 /// \param exp power of two to multiply with
- /// \return \a arg multiplied by 2 raised to \a exp
+ /// \return \a arg multplied by 2 raised to \a exp
 // template<typename T> typename enable<half,T>::type scalbln(T arg, long exp) { return functions::scalbln(arg, exp); }
 inline half scalbln(half arg, long exp) { return functions::scalbln(arg, exp); }
 inline half scalbln(expr arg, long exp) { return functions::scalbln(arg, exp); }
include/triton/external/hip.h (vendored): 319 changes
@@ -1,13 +1,35 @@
+/*
+Copyright (c) 2015 - 2022 Advanced Micro Devices, Inc. All rights reserved.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+*/
+
+#ifndef HIP_H
+#define HIP_H
+
+// Ignoring error-code return values from hip APIs is discouraged. On C++17,
+// we can make that yield a warning
+#if __cplusplus >= 201703L
+#define __HIP_NODISCARD [[nodiscard]]
+#else
+#define __HIP_NODISCARD
+#endif
+
 /*
  * @brief hipError_t
  * @enum
  * @ingroup Enumerations
  */
@@ -17,9 +39,7 @@
 // Developer note - when updating these, update the hipErrorName and hipErrorString functions in
 // NVCC and HCC paths Also update the hipCUDAErrorTohipError function in NVCC path.
 
-#include <cstddef>
-
-typedef enum hipError_t {
+typedef enum __HIP_NODISCARD hipError_t {
     hipSuccess = 0, ///< Successful completion.
     hipErrorInvalidValue = 1, ///< One or more of the parameters passed to the API call is NULL
                               ///< or not in an acceptable range.
@@ -73,6 +93,7 @@ typedef enum hipError_t {
     hipErrorInvalidHandle = 400,
     // Deprecated
     hipErrorInvalidResourceHandle = 400, ///< Resource handle (hipEvent_t or hipStream_t) invalid.
+    hipErrorIllegalState = 401, ///< Resource required is not in a valid state to perform operation.
     hipErrorNotFound = 500,
     hipErrorNotReady = 600, ///< Indicates that asynchronous operations enqueued earlier are not
                             ///< ready. This is not actually an error, but is used to distinguish
@@ -86,6 +107,7 @@ typedef enum hipError_t {
     hipErrorPeerAccessNotEnabled =
         705, ///< Peer access was never enabled from the current device.
     hipErrorSetOnActiveProcess = 708,
+    hipErrorContextIsDestroyed = 709,
     hipErrorAssert = 710, ///< Produced when the kernel calls assert.
     hipErrorHostMemoryAlreadyRegistered =
         712, ///< Produced when trying to lock a page-locked memory.
@@ -98,6 +120,32 @@ typedef enum hipError_t {
              ///< that was launched via cooperative launch APIs exceeds the maximum number of
              ///< allowed blocks for the current device
     hipErrorNotSupported = 801, ///< Produced when the hip API is not supported/implemented
+    hipErrorStreamCaptureUnsupported = 900, ///< The operation is not permitted when the stream
+                                            ///< is capturing.
+    hipErrorStreamCaptureInvalidated = 901, ///< The current capture sequence on the stream
+                                            ///< has been invalidated due to a previous error.
+    hipErrorStreamCaptureMerge = 902, ///< The operation would have resulted in a merge of
+                                      ///< two independent capture sequences.
+    hipErrorStreamCaptureUnmatched = 903, ///< The capture was not initiated in this stream.
+    hipErrorStreamCaptureUnjoined = 904, ///< The capture sequence contains a fork that was not
+                                         ///< joined to the primary stream.
+    hipErrorStreamCaptureIsolation = 905, ///< A dependency would have been created which crosses
+                                          ///< the capture sequence boundary. Only implicit
+                                          ///< in-stream ordering dependencies are allowed
+                                          ///< to cross the boundary
+    hipErrorStreamCaptureImplicit = 906, ///< The operation would have resulted in a disallowed
+                                         ///< implicit dependency on a current capture sequence
+                                         ///< from hipStreamLegacy.
+    hipErrorCapturedEvent = 907, ///< The operation is not permitted on an event which was last
+                                 ///< recorded in a capturing stream.
+    hipErrorStreamCaptureWrongThread = 908, ///< A stream capture sequence not initiated with
+                                            ///< the hipStreamCaptureModeRelaxed argument to
+                                            ///< hipStreamBeginCapture was passed to
+                                            ///< hipStreamEndCapture in a different thread.
+    hipErrorGraphExecUpdateFailure = 910, ///< This error indicates that the graph update
+                                          ///< not performed because it included changes which
+                                          ///< violated constraintsspecific to instantiated graph
+                                          ///< update.
     hipErrorUnknown = 999, //< Unknown error.
     // HSA Runtime Error Codes start here.
     hipErrorRuntimeMemory = 1052, ///< HSA runtime memory call returned error. Typically not seen
@@ -107,35 +155,154 @@ typedef enum hipError_t {
     hipErrorTbd ///< Marker that more error codes are needed.
 } hipError_t;
 
+#undef __HIP_NODISCARD
+
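The __HIP_NODISCARD change is easy to miss but has teeth: with a C++17 compiler, dropping a hipError_t result now triggers an unused-result warning. A sketch (hipEventDestroy is one of the calls routed through the dispatch table above):

hipEventDestroy(ev);                    // C++17: warning, return value ignored
hipError_t err = hipEventDestroy(ev);   // fine: the result is inspected
if (err != hipSuccess) { /* handle or log */ }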
 /*
  * @brief hipDeviceAttribute_t
  * @enum
  * @ingroup Enumerations
  */
 typedef enum hipDeviceAttribute_t {
+    hipDeviceAttributeCudaCompatibleBegin = 0,
+
+    hipDeviceAttributeEccEnabled = hipDeviceAttributeCudaCompatibleBegin, ///< Whether ECC support is enabled.
+    hipDeviceAttributeAccessPolicyMaxWindowSize, ///< Cuda only. The maximum size of the window policy in bytes.
+    hipDeviceAttributeAsyncEngineCount, ///< Cuda only. Asynchronous engines number.
+    hipDeviceAttributeCanMapHostMemory, ///< Whether host memory can be mapped into device address space
+    hipDeviceAttributeCanUseHostPointerForRegisteredMem,///< Cuda only. Device can access host registered memory
+                                                        ///< at the same virtual address as the CPU
+    hipDeviceAttributeClockRate, ///< Peak clock frequency in kilohertz.
+    hipDeviceAttributeComputeMode, ///< Compute mode that device is currently in.
+    hipDeviceAttributeComputePreemptionSupported, ///< Cuda only. Device supports Compute Preemption.
+    hipDeviceAttributeConcurrentKernels, ///< Device can possibly execute multiple kernels concurrently.
+    hipDeviceAttributeConcurrentManagedAccess, ///< Device can coherently access managed memory concurrently with the CPU
+    hipDeviceAttributeCooperativeLaunch, ///< Support cooperative launch
+    hipDeviceAttributeCooperativeMultiDeviceLaunch, ///< Support cooperative launch on multiple devices
+    hipDeviceAttributeDeviceOverlap, ///< Cuda only. Device can concurrently copy memory and execute a kernel.
+                                     ///< Deprecated. Use instead asyncEngineCount.
+    hipDeviceAttributeDirectManagedMemAccessFromHost, ///< Host can directly access managed memory on
+                                                      ///< the device without migration
+    hipDeviceAttributeGlobalL1CacheSupported, ///< Cuda only. Device supports caching globals in L1
+    hipDeviceAttributeHostNativeAtomicSupported, ///< Cuda only. Link between the device and the host supports native atomic operations
+    hipDeviceAttributeIntegrated, ///< Device is integrated GPU
+    hipDeviceAttributeIsMultiGpuBoard, ///< Multiple GPU devices.
+    hipDeviceAttributeKernelExecTimeout, ///< Run time limit for kernels executed on the device
+    hipDeviceAttributeL2CacheSize, ///< Size of L2 cache in bytes. 0 if the device doesn't have L2 cache.
+    hipDeviceAttributeLocalL1CacheSupported, ///< caching locals in L1 is supported
+    hipDeviceAttributeLuid, ///< Cuda only. 8-byte locally unique identifier in 8 bytes. Undefined on TCC and non-Windows platforms
+    hipDeviceAttributeLuidDeviceNodeMask, ///< Cuda only. Luid device node mask. Undefined on TCC and non-Windows platforms
+    hipDeviceAttributeComputeCapabilityMajor, ///< Major compute capability version number.
+    hipDeviceAttributeManagedMemory, ///< Device supports allocating managed memory on this system
+    hipDeviceAttributeMaxBlocksPerMultiProcessor, ///< Cuda only. Max block size per multiprocessor
+    hipDeviceAttributeMaxBlockDimX, ///< Max block size in width.
+    hipDeviceAttributeMaxBlockDimY, ///< Max block size in height.
+    hipDeviceAttributeMaxBlockDimZ, ///< Max block size in depth.
+    hipDeviceAttributeMaxGridDimX, ///< Max grid size in width.
+    hipDeviceAttributeMaxGridDimY, ///< Max grid size in height.
+    hipDeviceAttributeMaxGridDimZ, ///< Max grid size in depth.
+    hipDeviceAttributeMaxSurface1D, ///< Maximum size of 1D surface.
+    hipDeviceAttributeMaxSurface1DLayered, ///< Cuda only. Maximum dimensions of 1D layered surface.
+    hipDeviceAttributeMaxSurface2D, ///< Maximum dimension (width, height) of 2D surface.
+    hipDeviceAttributeMaxSurface2DLayered, ///< Cuda only. Maximum dimensions of 2D layered surface.
+    hipDeviceAttributeMaxSurface3D, ///< Maximum dimension (width, height, depth) of 3D surface.
+    hipDeviceAttributeMaxSurfaceCubemap, ///< Cuda only. Maximum dimensions of Cubemap surface.
+    hipDeviceAttributeMaxSurfaceCubemapLayered, ///< Cuda only. Maximum dimension of Cubemap layered surface.
+    hipDeviceAttributeMaxTexture1DWidth, ///< Maximum size of 1D texture.
+    hipDeviceAttributeMaxTexture1DLayered, ///< Cuda only. Maximum dimensions of 1D layered texture.
+    hipDeviceAttributeMaxTexture1DLinear, ///< Maximum number of elements allocatable in a 1D linear texture.
+                                          ///< Use cudaDeviceGetTexture1DLinearMaxWidth() instead on Cuda.
+    hipDeviceAttributeMaxTexture1DMipmap, ///< Cuda only. Maximum size of 1D mipmapped texture.
+    hipDeviceAttributeMaxTexture2DWidth, ///< Maximum dimension width of 2D texture.
+    hipDeviceAttributeMaxTexture2DHeight, ///< Maximum dimension hight of 2D texture.
+    hipDeviceAttributeMaxTexture2DGather, ///< Cuda only. Maximum dimensions of 2D texture if gather operations performed.
+    hipDeviceAttributeMaxTexture2DLayered, ///< Cuda only. Maximum dimensions of 2D layered texture.
+    hipDeviceAttributeMaxTexture2DLinear, ///< Cuda only. Maximum dimensions (width, height, pitch) of 2D textures bound to pitched memory.
+    hipDeviceAttributeMaxTexture2DMipmap, ///< Cuda only. Maximum dimensions of 2D mipmapped texture.
+    hipDeviceAttributeMaxTexture3DWidth, ///< Maximum dimension width of 3D texture.
+    hipDeviceAttributeMaxTexture3DHeight, ///< Maximum dimension height of 3D texture.
+    hipDeviceAttributeMaxTexture3DDepth, ///< Maximum dimension depth of 3D texture.
+    hipDeviceAttributeMaxTexture3DAlt, ///< Cuda only. Maximum dimensions of alternate 3D texture.
+    hipDeviceAttributeMaxTextureCubemap, ///< Cuda only. Maximum dimensions of Cubemap texture
+    hipDeviceAttributeMaxTextureCubemapLayered, ///< Cuda only. Maximum dimensions of Cubemap layered texture.
+    hipDeviceAttributeMaxThreadsDim, ///< Maximum dimension of a block
+    hipDeviceAttributeMaxThreadsPerBlock, ///< Maximum number of threads per block.
+    hipDeviceAttributeMaxThreadsPerMultiProcessor, ///< Maximum resident threads per multiprocessor.
+    hipDeviceAttributeMaxPitch, ///< Maximum pitch in bytes allowed by memory copies
+    hipDeviceAttributeMemoryBusWidth, ///< Global memory bus width in bits.
+    hipDeviceAttributeMemoryClockRate, ///< Peak memory clock frequency in kilohertz.
+    hipDeviceAttributeComputeCapabilityMinor, ///< Minor compute capability version number.
+    hipDeviceAttributeMultiGpuBoardGroupID, ///< Cuda only. Unique ID of device group on the same multi-GPU board
+    hipDeviceAttributeMultiprocessorCount, ///< Number of multiprocessors on the device.
+    hipDeviceAttributeName, ///< Device name.
+    hipDeviceAttributePageableMemoryAccess, ///< Device supports coherently accessing pageable memory
+                                            ///< without calling hipHostRegister on it
+    hipDeviceAttributePageableMemoryAccessUsesHostPageTables, ///< Device accesses pageable memory via the host's page tables
+    hipDeviceAttributePciBusId, ///< PCI Bus ID.
+    hipDeviceAttributePciDeviceId, ///< PCI Device ID.
+    hipDeviceAttributePciDomainID, ///< PCI Domain ID.
+    hipDeviceAttributePersistingL2CacheMaxSize, ///< Cuda11 only. Maximum l2 persisting lines capacity in bytes
+    hipDeviceAttributeMaxRegistersPerBlock, ///< 32-bit registers available to a thread block. This number is shared
+                                            ///< by all thread blocks simultaneously resident on a multiprocessor.
+    hipDeviceAttributeMaxRegistersPerMultiprocessor, ///< 32-bit registers available per block.
+    hipDeviceAttributeReservedSharedMemPerBlock, ///< Cuda11 only. Shared memory reserved by CUDA driver per block.
+    hipDeviceAttributeMaxSharedMemoryPerBlock, ///< Maximum shared memory available per block in bytes.
+    hipDeviceAttributeSharedMemPerBlockOptin, ///< Cuda only. Maximum shared memory per block usable by special opt in.
+    hipDeviceAttributeSharedMemPerMultiprocessor, ///< Cuda only. Shared memory available per multiprocessor.
+    hipDeviceAttributeSingleToDoublePrecisionPerfRatio, ///< Cuda only. Performance ratio of single precision to double precision.
+    hipDeviceAttributeStreamPrioritiesSupported, ///< Cuda only. Whether to support stream priorities.
+    hipDeviceAttributeSurfaceAlignment, ///< Cuda only. Alignment requirement for surfaces
+    hipDeviceAttributeTccDriver, ///< Cuda only. Whether device is a Tesla device using TCC driver
+    hipDeviceAttributeTextureAlignment, ///< Alignment requirement for textures
+    hipDeviceAttributeTexturePitchAlignment, ///< Pitch alignment requirement for 2D texture references bound to pitched memory;
+    hipDeviceAttributeTotalConstantMemory, ///< Constant memory size in bytes.
+    hipDeviceAttributeTotalGlobalMem, ///< Global memory available on devicice.
+    hipDeviceAttributeUnifiedAddressing, ///< Cuda only. An unified address space shared with the host.
+    hipDeviceAttributeUuid, ///< Cuda only. Unique ID in 16 byte.
+    hipDeviceAttributeWarpSize, ///< Warp size in threads.
+    hipDeviceAttributeMemoryPoolsSupported, ///< Device supports HIP Stream Ordered Memory Allocator
+
+    hipDeviceAttributeCudaCompatibleEnd = 9999,
+    hipDeviceAttributeAmdSpecificBegin = 10000,
+
+    hipDeviceAttributeClockInstructionRate = hipDeviceAttributeAmdSpecificBegin, ///< Frequency in khz of the timer used by the device-side "clock*"
+    hipDeviceAttributeArch, ///< Device architecture
+    hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, ///< Maximum Shared Memory PerMultiprocessor.
+    hipDeviceAttributeGcnArch, ///< Device gcn architecture
+    hipDeviceAttributeGcnArchName, ///< Device gcnArch name in 256 bytes
+    hipDeviceAttributeHdpMemFlushCntl, ///< Address of the HDP_MEM_COHERENCY_FLUSH_CNTL register
+    hipDeviceAttributeHdpRegFlushCntl, ///< Address of the HDP_REG_COHERENCY_FLUSH_CNTL register
+    hipDeviceAttributeCooperativeMultiDeviceUnmatchedFunc, ///< Supports cooperative launch on multiple
+                                                           ///< devices with unmatched functions
+    hipDeviceAttributeCooperativeMultiDeviceUnmatchedGridDim, ///< Supports cooperative launch on multiple
+                                                              ///< devices with unmatched grid dimensions
+    hipDeviceAttributeCooperativeMultiDeviceUnmatchedBlockDim, ///< Supports cooperative launch on multiple
+                                                               ///< devices with unmatched block dimensions
+    hipDeviceAttributeCooperativeMultiDeviceUnmatchedSharedMem, ///< Supports cooperative launch on multiple
+                                                                ///< devices with unmatched shared memories
+    hipDeviceAttributeIsLargeBar, ///< Whether it is LargeBar
+    hipDeviceAttributeAsicRevision, ///< Revision of the GPU in this device
+    hipDeviceAttributeCanUseStreamWaitValue, ///< '1' if Device supports hipStreamWaitValue32() and
+                                             ///< hipStreamWaitValue64(), '0' otherwise.
+    hipDeviceAttributeImageSupport, ///< '1' if Device supports image, '0' otherwise.
|
||||
hipDeviceAttributePhysicalMultiProcessorCount, ///< All available physical compute
|
||||
///< units for the device
|
||||
hipDeviceAttributeFineGrainSupport, ///< '1' if Device supports fine grain, '0' otherwise
|
||||
|
||||
hipDeviceAttributeAmdSpecificEnd = 19999,
|
||||
hipDeviceAttributeVendorSpecificBegin = 20000,
|
||||
// Extended attributes for vendors
|
||||
} hipDeviceAttribute_t;
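For orientation, a minimal (hedged) sketch of how this enum is consumed; `hipDeviceGetAttribute` is the standard HIP query entry point and is not part of this diff:

```cpp
#include <hip/hip_runtime.h>
#include <cstdio>

int main() {
    // Query a single attribute from the enum above for device 0.
    int warp_size = 0;
    hipDeviceGetAttribute(&warp_size, hipDeviceAttributeWarpSize, /*deviceId=*/0);
    std::printf("warp size: %d\n", warp_size);  // typically 64 on AMD, 32 on NVIDIA
    return 0;
}
```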

#define HIP_LAUNCH_PARAM_BUFFER_POINTER ((void*)0x01)
#define HIP_LAUNCH_PARAM_BUFFER_SIZE ((void*)0x02)
#define HIP_LAUNCH_PARAM_END ((void*)0x03)
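These three markers implement the packed-argument launch protocol that `hip_enqueue` further down relies on (it passes `nullptr` for `kernelParams` and a `config` array as `extra`). A hedged sketch of that protocol; `KernelArgs` and `f` are hypothetical stand-ins:

```cpp
#include <hip/hip_runtime.h>

// Hypothetical packed argument block; the runtime copies `size` bytes of it
// instead of walking an array of per-argument pointers.
struct KernelArgs { float* out; int n; };

hipError_t launch_packed(hipFunction_t f, KernelArgs args, hipStream_t stream) {
    size_t size = sizeof(args);
    void* config[] = {
        HIP_LAUNCH_PARAM_BUFFER_POINTER, &args,
        HIP_LAUNCH_PARAM_BUFFER_SIZE,    &size,
        HIP_LAUNCH_PARAM_END
    };
    return hipModuleLaunchKernel(f, /*grid=*/1, 1, 1, /*block=*/64, 1, 1,
                                 /*sharedMemBytes=*/0, stream,
                                 /*kernelParams=*/nullptr, config);
}
```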
// API-visible structures
typedef struct ihipCtx_t* hipCtx_t;

// Note many APIs also use integer deviceIds as an alternative to the device pointer:
typedef int hipDevice_t;

typedef enum hipDeviceP2PAttr {
hipDevP2PAttrPerformanceRank = 0,
hipDevP2PAttrAccessSupported,
hipDevP2PAttrNativeAtomicSupported,
hipDevP2PAttrHipArrayAccessSupported
} hipDeviceP2PAttr;

typedef struct ihipStream_t* hipStream_t;

#define hipIpcMemLazyEnablePeerAccess 0

#define HIP_IPC_HANDLE_SIZE 64

typedef struct hipIpcMemHandle_st {
char reserved[HIP_IPC_HANDLE_SIZE];
} hipIpcMemHandle_t;

typedef struct hipIpcEventHandle_st {
char reserved[HIP_IPC_HANDLE_SIZE];
} hipIpcEventHandle_t;

typedef struct ihipModule_t* hipModule_t;

typedef struct ihipModuleSymbol_t* hipFunction_t;

typedef struct hipFuncAttributes {
@@ -150,91 +317,8 @@ typedef struct hipFuncAttributes {
int ptxVersion;
size_t sharedSizeBytes;
} hipFuncAttributes;
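These are the fields `load_binary_hsaco` later reads back (`numRegs`, `localSizeBytes`, `sharedSizeBytes`). A hedged sketch mirroring the `hipFuncGetAttributes` call that this tree leaves commented out — without it the struct is read uninitialized:

```cpp
// Sketch only: populate hipFuncAttributes for a loaded function `fun`
// (hypothetical handle), then derive the counters the binding returns.
hipFuncAttributes attr{};
drv::dispatch::hipFuncGetAttributes(&attr, fun);   // commented out in this tree
int n_regs   = attr.numRegs;                       // allocated registers
int n_spills = attr.localSizeBytes / 4;            // spilled bytes -> 32-bit slots
```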
typedef struct ihipEvent_t* hipEvent_t;

/*
* @brief hipDeviceAttribute_t
* @enum
* @ingroup Enumerations
*/
typedef enum hipDeviceAttribute_t {
hipDeviceAttributeMaxThreadsPerBlock, ///< Maximum number of threads per block.
hipDeviceAttributeMaxBlockDimX, ///< Maximum x-dimension of a block.
hipDeviceAttributeMaxBlockDimY, ///< Maximum y-dimension of a block.
hipDeviceAttributeMaxBlockDimZ, ///< Maximum z-dimension of a block.
hipDeviceAttributeMaxGridDimX, ///< Maximum x-dimension of a grid.
hipDeviceAttributeMaxGridDimY, ///< Maximum y-dimension of a grid.
hipDeviceAttributeMaxGridDimZ, ///< Maximum z-dimension of a grid.
hipDeviceAttributeMaxSharedMemoryPerBlock, ///< Maximum shared memory available per block in
///< bytes.
hipDeviceAttributeTotalConstantMemory, ///< Constant memory size in bytes.
hipDeviceAttributeWarpSize, ///< Warp size in threads.
hipDeviceAttributeMaxRegistersPerBlock, ///< Maximum number of 32-bit registers available to a
///< thread block. This number is shared by all thread
///< blocks simultaneously resident on a
///< multiprocessor.
hipDeviceAttributeClockRate, ///< Peak clock frequency in kilohertz.
hipDeviceAttributeMemoryClockRate, ///< Peak memory clock frequency in kilohertz.
hipDeviceAttributeMemoryBusWidth, ///< Global memory bus width in bits.
hipDeviceAttributeMultiprocessorCount, ///< Number of multiprocessors on the device.
hipDeviceAttributeComputeMode, ///< Compute mode that device is currently in.
hipDeviceAttributeL2CacheSize, ///< Size of L2 cache in bytes. 0 if the device doesn't have L2
///< cache.
hipDeviceAttributeMaxThreadsPerMultiProcessor, ///< Maximum resident threads per
///< multiprocessor.
hipDeviceAttributeComputeCapabilityMajor, ///< Major compute capability version number.
hipDeviceAttributeComputeCapabilityMinor, ///< Minor compute capability version number.
hipDeviceAttributeConcurrentKernels, ///< Device can possibly execute multiple kernels
///< concurrently.
hipDeviceAttributePciBusId, ///< PCI Bus ID.
hipDeviceAttributePciDeviceId, ///< PCI Device ID.
hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, ///< Maximum shared memory per
///< multiprocessor.
hipDeviceAttributeIsMultiGpuBoard, ///< Multiple GPU devices.
hipDeviceAttributeIntegrated, ///< iGPU.
hipDeviceAttributeCooperativeLaunch, ///< Supports cooperative launch.
hipDeviceAttributeCooperativeMultiDeviceLaunch, ///< Supports cooperative launch on multiple devices.
hipDeviceAttributeMaxTexture1DWidth, ///< Maximum number of elements in 1D images.
hipDeviceAttributeMaxTexture2DWidth, ///< Maximum dimension width of 2D images in image elements.
hipDeviceAttributeMaxTexture2DHeight, ///< Maximum dimension height of 2D images in image elements.
hipDeviceAttributeMaxTexture3DWidth, ///< Maximum dimension width of 3D images in image elements.
hipDeviceAttributeMaxTexture3DHeight, ///< Maximum dimension height of 3D images in image elements.
hipDeviceAttributeMaxTexture3DDepth, ///< Maximum dimension depth of 3D images in image elements.

hipDeviceAttributeHdpMemFlushCntl, ///< Address of the HDP_MEM_COHERENCY_FLUSH_CNTL register.
hipDeviceAttributeHdpRegFlushCntl, ///< Address of the HDP_REG_COHERENCY_FLUSH_CNTL register.

hipDeviceAttributeMaxPitch, ///< Maximum pitch in bytes allowed by memory copies.
hipDeviceAttributeTextureAlignment, ///< Alignment requirement for textures.
hipDeviceAttributeTexturePitchAlignment, ///< Pitch alignment requirement for 2D texture references bound to pitched memory.
hipDeviceAttributeKernelExecTimeout, ///< Run time limit for kernels executed on the device.
hipDeviceAttributeCanMapHostMemory, ///< Device can map host memory into device address space.
hipDeviceAttributeEccEnabled, ///< Device has ECC support enabled.

hipDeviceAttributeCooperativeMultiDeviceUnmatchedFunc, ///< Supports cooperative launch on multiple
///< devices with unmatched functions.
hipDeviceAttributeCooperativeMultiDeviceUnmatchedGridDim, ///< Supports cooperative launch on multiple
///< devices with unmatched grid dimensions.
hipDeviceAttributeCooperativeMultiDeviceUnmatchedBlockDim, ///< Supports cooperative launch on multiple
///< devices with unmatched block dimensions.
hipDeviceAttributeCooperativeMultiDeviceUnmatchedSharedMem, ///< Supports cooperative launch on multiple
///< devices with unmatched shared memories.
hipDeviceAttributeAsicRevision, ///< Revision of the GPU in this device.
hipDeviceAttributeManagedMemory, ///< Device supports allocating managed memory on this system.
hipDeviceAttributeDirectManagedMemAccessFromHost, ///< Host can directly access managed memory on
///< the device without migration.
hipDeviceAttributeConcurrentManagedAccess, ///< Device can coherently access managed memory
///< concurrently with the CPU.
hipDeviceAttributePageableMemoryAccess, ///< Device supports coherently accessing pageable memory
///< without calling hipHostRegister on it.
hipDeviceAttributePageableMemoryAccessUsesHostPageTables, ///< Device accesses pageable memory via
///< the host's page tables.
hipDeviceAttributeCanUseStreamWaitValue ///< '1' if device supports hipStreamWaitValue32() and
///< hipStreamWaitValue64(), '0' otherwise.

} hipDeviceAttribute_t;

typedef void* hipDeviceptr_t;

/*
@@ -262,7 +346,6 @@ typedef enum hipJitOption {
hipJitOptionFastCompile,
hipJitOptionNumOptions
} hipJitOption;

/**
* @warning On AMD devices and some Nvidia devices, these hints and controls are ignored.
*/
@@ -271,7 +354,6 @@ typedef enum hipFuncAttribute {
hipFuncAttributePreferredSharedMemoryCarveout = 9,
hipFuncAttributeMax
} hipFuncAttribute;

/**
* @warning On AMD devices and some Nvidia devices, these hints and controls are ignored.
*/
@@ -282,7 +364,4 @@ typedef enum hipFuncCache_t {
hipFuncCachePreferEqual, ///< prefer equal size L1 cache and shared memory
} hipFuncCache_t;


#define HIP_LAUNCH_PARAM_BUFFER_POINTER ((void*)0x01)
#define HIP_LAUNCH_PARAM_BUFFER_SIZE ((void*)0x02)
#define HIP_LAUNCH_PARAM_END ((void*)0x03)
#endif

@@ -45,6 +45,7 @@ public:
value *get_int64(uint64_t val);
value *get_float16(float val);
value *get_float32(float val);
value *get_float64(float val);
value *get_range(int32_t lo, int32_t hi);
// Types
type *get_void_ty();

@@ -196,7 +196,7 @@ mma_layout::mma_layout(size_t num_warps,
tensor_core_type_ = get_mma_type(dot);
/* fragments per warp */
// try to make things as square as possible to maximize data re-use
if(tgt->as_nvidia()->sm() < 80){
if(tgt->as_nvidia() && tgt->as_nvidia()->sm() < 80){
fpw_ = {2, 2, 1};
auto ord_a = layout_a->get_order();
auto ord_b = layout_b->get_order();

@@ -79,11 +79,13 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(
ir::module& ir, llvm::LLVMContext& ctx, codegen::target* target,
int num_warps, int num_stages, int& shared_static,
const ExternLibMap& extern_lib_map) {
std::cout << "pass.cc: add_passes_to_emit_bin" << std::endl;
// generate llvm code
std::string name = ir.get_function_list()[0]->get_name();
std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
// optimizations
bool has_sm80 = target->as_nvidia() && target->as_nvidia()->sm() >= 80;
// bool has_sm80 = target->as_nvidia() && target->as_nvidia()->sm() >= 80;
bool has_sm80 = false;
// create passes
codegen::analysis::align align;
codegen::transform::inliner inliner;
@@ -149,13 +151,6 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(
// ir.print(std::cout);
isel.visit(ir, *llvm);
shared_static = allocation.allocated_size();
if (target->as_nvidia() && target->as_nvidia()->sm() < 70) {
// sm < 70 (Pascal) has little shared memory resource.
// Instead of having "Error: Invalid argument" on launching a kernel, let's throw an error here.
if (shared_static >= 65536) {
throw std::runtime_error("Device does not support shared memory of " + std::to_string(shared_static) + " bytes");
}
}

if (isel.get_extern_lib_map().size() > 0) {
// If there's any extern lib calls,
@@ -16,7 +16,13 @@
#include "triton/ir/utils.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IR/Type.h"
#ifdef USE_ROCM
#include "llvm/IR/IntrinsicsAMDGPU.h"
#else
#include "llvm/IR/IntrinsicsNVPTX.h"
#endif
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/InlineAsm.h"
@@ -91,6 +97,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
#define f16_ty builder_->getHalfTy()
#define bf16_ty builder_->getInt16Ty()
#define f32_ty builder_->getFloatTy()
#define f64_ty builder_->getDoubleTy()
#define i1_ty builder_->getInt1Ty()
#define i8_ty builder_->getInt8Ty()
#define i16_ty builder_->getInt16Ty()
@@ -410,48 +417,44 @@ void generator::visit_binary_operator(ir::binary_operator*x) {
for(indices_t idx: idxs_.at(x)){
Value *lhs = vals_[x->get_operand(0)][idx];
Value *rhs = vals_[x->get_operand(1)][idx];
// manually select bf16 bin op
if (x->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty()) {
assert(x->get_operand(1)->get_type()->get_scalar_ty()->is_bf16_ty());
if (x->get_op() == tt::FAdd) { // a + b = a * 1.0 + b
if (x->get_op() == tt::FAdd) {
InlineAsm *bf16_add_asm =
InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
"{ .reg .b16 c; \n\t"
" mov.b16 c, 0x3f80U; \n\t" // 1.0
" fma.rn.bf16 $0, $1, c, $2; } \n\t",
"=h,h,h", false);
InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
"v_lshlrev_b32 $1, 16, $1 "
"v_lshlrev_b32 $2, 16, $2 "
"v_add_f32 $0, $1, $2 "
"v_lshrrev_b32 $0, 16, $0 ",
"=v,v,v", false);
vals_[x][idx] = builder_->CreateCall(bf16_add_asm, {lhs, rhs});
} else if (x->get_op() == tt::FSub) { // a - b = b * (-1.0) + a
InlineAsm *bf16_sub_asm =
InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
" { .reg .b16 c; \n\t"
" mov.b16 c, 0xbf80U; \n\t" // -1.0
" fma.rn.bf16 $0, $2, c, $1;} \n\t",
"=h,h,h", false);
InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
"v_lshlrev_b32 $1, 16, $1 "
"v_lshlrev_b32 $2, 16, $2 "
"v_sub_f32 $0, $1, $2 "
"v_lshrrev_b32 $0, 16, $0 ",
"=v,v,v", false);
vals_[x][idx] = builder_->CreateCall(bf16_sub_asm, {lhs, rhs});
} else if (x->get_op() == tt::FMul) { // a * b = a*b + 0
InlineAsm *bf16_mul_asm =
InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
" { .reg .b16 c; \n\t"
" mov.b16 c, 0x8000U; \n\t" // -0.0
" fma.rn.bf16 $0, $1, $2, c;} \n\t",
"=h,h,h", false);
"v_lshlrev_b32 $1, 16, $1 "
"v_lshlrev_b32 $2, 16, $2 "
"v_mul_f32 $0, $1, $2 "
"v_lshrrev_b32 $0, 16, $0 ",
"=v,v,v", false);
vals_[x][idx] = builder_->CreateCall(bf16_mul_asm, {lhs, rhs});
} else
throw std::runtime_error("invalid bin op for bf16");
} else { // not bf16
}
else {
auto op = cvt(x->get_op());
if(op == ll::Add)
vals_[x][idx] = add(lhs, rhs);
else if(op == ll::Mul)
vals_[x][idx] = mul(lhs, rhs);
else if(op == ll::FDiv && !x->get_fdiv_ieee_rounding() &&
x->get_type()->get_scalar_ty()->is_fp32_ty()){
InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {f32_ty, f32_ty}, false),
" div.full.f32 $0, $1, $2;", "=r,r,r", false);
vals_[x][idx] = builder_->CreateCall(ptx, {lhs, rhs});

}
else
vals_[x][idx] = bin_op(op, lhs, rhs);
}
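The ROCm branch replaces the PTX `fma.rn.bf16` tricks with plain AMDGCN VALU ops: a bf16 value occupies the upper 16 bits of an f32, so each operand is widened with `v_lshlrev_b32`, combined in f32, and narrowed back with `v_lshrrev_b32`. A host-side C++ sketch of the same bit trick, for illustration only (not code from this patch):

```cpp
#include <cstdint>
#include <cstring>

// bf16 add via f32: widen by shifting left 16, add in f32, narrow by
// shifting right 16. The final shift truncates rather than rounds,
// matching the v_lshrrev sequence (unlike cvt.rn on the NVIDIA path).
static uint16_t bf16_add(uint16_t a, uint16_t b) {
    uint32_t wa = uint32_t(a) << 16, wb = uint32_t(b) << 16;
    float fa, fb;
    std::memcpy(&fa, &wa, sizeof fa);
    std::memcpy(&fb, &wb, sizeof fb);
    float fr = fa + fb;
    uint32_t wr;
    std::memcpy(&wr, &fr, sizeof wr);
    return uint16_t(wr >> 16);
}
```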

@@ -754,7 +757,7 @@ Value* generator::bf16_to_fp32(Value *in0){
}

Value* generator::fp32_to_bf16(Value *in0){
if(tgt_->as_nvidia()->sm() >= 80){
if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80){
InlineAsm *ptx = InlineAsm::get(FunctionType::get(bf16_ty, {f32_ty}, false),
"cvt.rn.bf16.f32 $0, $1;", "=h,r", false);
return call(ptx, {in0});
@@ -1120,6 +1123,22 @@ void generator::visit_load_inst(ir::load_inst* x){
ir::value *op = x->get_pointer_operand();
ir::masked_load_inst *mx = dynamic_cast<ir::masked_load_inst*>(x);
Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty());

#ifdef USE_ROCM
// code generation
auto idxs = idxs_.at(x);
for(size_t i = 0; i < idxs.size(); i += 1){
indices_t idx = idxs[i];
// pointer value
Value *ptr = vals_[op][idx];

// create load
Value *_ret = builder_->CreateLoad(ty, ptr);

// upload to global vals map
vals_[x][idx] = _ret;
}
#else
// compute vector width
size_t vec = 1;
bool is_mma_first_row = false;
@@ -1298,6 +1317,7 @@ void generator::visit_load_inst(ir::load_inst* x){
for(size_t ii = 0; ii < vec; ii++)
vals_[x][idxs[i+ii]] = extract_elt(rets[ii/tmp], ii % tmp);
}
#endif
}

void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
@@ -1316,6 +1336,23 @@ void generator::visit_store_inst(ir::store_inst * x){
// operands
ir::value *ptr_op = x->get_pointer_operand();
ir::value *val_op = x->get_value_operand();
#ifdef USE_ROCM
auto idxs = idxs_.at(val_op);
Type *ty = cvt(val_op->get_type()->get_scalar_ty());

for (size_t i = 0; i < idxs.size(); i += 1)
{
auto idx = idxs[i];
// pointer
Value *ptr = vals_[ptr_op][idx];

// value
Value *val = vals_.at(val_op)[idxs[i]];

// store value at pointer
store(val, ptr);
}
#else
ir::value *msk_op = nullptr;
if(auto* msk_st = dynamic_cast<ir::masked_store_inst*>(x))
msk_op = msk_st->get_mask_operand();
@@ -1431,6 +1468,7 @@ void generator::visit_store_inst(ir::store_inst * x){
args.push_back(policies_.at(x->get_eviction_policy()));
call(_asm, args);
}
#endif
}
void generator::visit_unmasked_store_inst(ir::unmasked_store_inst* x) {
visit_store_inst(x);
@@ -1549,7 +1587,12 @@ void generator::visit_exp_inst(ir::exp_inst* x){
Constant *log2e = ConstantFP::get(f32_ty, 1.4426950408889634);
std::vector<llvm::Type*> tys = {f32_ty};
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
#ifdef USE_ROCM
llvm::Function *ex2 = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::exp2, tys);
#else
InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $0;", "=f,0", false);
#endif

for(auto idx: idxs_.at(x)){
Value *ex2arg = fmul(vals_[x->get_operand(0)][idx], log2e);
// Value *ex2arg = vals_[x->get_operand(0)][idx];
@@ -1563,7 +1606,11 @@ void generator::visit_exp_inst(ir::exp_inst* x){
void generator::visit_cos_inst(ir::cos_inst* x){
std::vector<llvm::Type*> tys = {f32_ty};
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
#ifdef USE_ROCM
llvm::Function *cos = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::cos, tys);
#else
InlineAsm *cos = InlineAsm::get(fn_ty, "cos.approx.f32 $0, $0;", "=f,0", false);
#endif
for(auto idx: idxs_.at(x)){
vals_[x][idx] = call(cos, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
}
@@ -1589,7 +1636,11 @@ void generator::visit_umulhi_inst(ir::umulhi_inst* x){
void generator::visit_sin_inst(ir::sin_inst* x){
std::vector<llvm::Type*> tys = {f32_ty};
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
#ifdef USE_ROCM
llvm::Function *sin = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::sin, tys);
#else
InlineAsm *sin = InlineAsm::get(fn_ty, "sin.approx.f32 $0, $0;", "=f,0", false);
#endif
for(auto idx: idxs_.at(x)){
vals_[x][idx] = call(sin, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
}
@@ -1602,7 +1653,11 @@ void generator::visit_log_inst(ir::log_inst* x){
Constant *rcplog2e = ConstantFP::get(f32_ty, 0.6931471805599453);
std::vector<llvm::Type*> tys = {f32_ty};
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
#ifdef USE_ROCM
llvm::Function *lg2 = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::log2, tys);
#else
InlineAsm *lg2 = InlineAsm::get(fn_ty, "lg2.approx.f32 $0, $1;", "=f,f", false);
#endif
for(auto idx: idxs_.at(x)){
Value *lg2arg = call(lg2, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
vals_[x][idx] = fmul(lg2arg, rcplog2e);
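Both the intrinsic and the PTX paths lean on the same change-of-base identities, with the constants materialized above:

```latex
% exp via exp2, and log via log2:
\[
  e^{x} = 2^{\,x \log_2 e}, \qquad \log_2 e \approx 1.4426950408889634
\]
\[
  \ln x = \log_2(x)\cdot\ln 2, \qquad \ln 2 \approx 0.6931471805599453
\]
```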

@@ -1612,6 +1667,35 @@ void generator::visit_log_inst(ir::log_inst* x){
/**
* \brief Code Generation for `atomic_cas`
*/
#if defined(USE_ROCM)
void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
BasicBlock *current = builder_->GetInsertBlock();
Module *module = current->getModule();
Value *tid = tgt_->get_local_id(module, *builder_, 0);
Value *pred = icmp_eq(tid, i32(0));
BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
add_barrier();
tgt_->add_memfence(module, *builder_);
cond_br(pred, tid_0_bb, tid_0_done_bb);
builder_->SetInsertPoint(tid_0_bb);
Value *cas_ptr = vals_[cas->get_operand(0)][{}];
Value *cas_cmp = vals_[cas->get_operand(1)][{}];
Value *cas_val = vals_[cas->get_operand(2)][{}];
Value *old = atomic_cmp_xchg(cas_ptr, cas_cmp, cas_val, MaybeAlign(), AtomicOrdering::Monotonic, AtomicOrdering::Monotonic);
old = extract_val(old, std::vector<unsigned>{0});
Value *atom_ptr;
atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(cas)))), "");
atom_ptr = bit_cast(atom_ptr, ptr_ty(old->getType(), 3));
store(old, atom_ptr);
br(tid_0_done_bb);
builder_->SetInsertPoint(tid_0_done_bb);
tgt_->add_memfence(module, *builder_);
add_barrier();
vals_[cas][{}] = load(atom_ptr);
add_barrier();
}
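In HIP terms, the block the generator emits above corresponds roughly to the following — a hedged sketch, not code from this patch: thread 0 performs the CAS, stages the old value in shared memory, and barriers broadcast the single result to the whole block.

```cpp
#include <hip/hip_runtime.h>

__global__ void cas_broadcast(int* ptr, int cmp, int val, int* out) {
    __shared__ int old_val;
    __syncthreads();                          // add_barrier() + memfence
    if (threadIdx.x == 0)
        old_val = atomicCAS(ptr, cmp, val);   // only tid 0 touches memory
    __syncthreads();                          // publish via shared memory
    out[threadIdx.x] = old_val;               // every thread sees the same value
}
```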
#else
void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
BasicBlock *current = builder_->GetInsertBlock();
Module *module = current->getModule();
@@ -1646,12 +1730,66 @@ void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
vals_[cas][{}] = load(atom_ptr);
add_barrier();
}
#endif // defined(USE_ROCM)

/**
* \brief Code Generation for `atomic_rmw`
*/
#if defined(USE_ROCM)
void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
ir::value* ptr = atom->get_operand(0);
if (atom->get_op() == ir::atomic_rmw_op_t::Add ||
atom->get_op() == ir::atomic_rmw_op_t::Max ||
atom->get_op() == ir::atomic_rmw_op_t::Min ||
atom->get_op() == ir::atomic_rmw_op_t::UMax ||
atom->get_op() == ir::atomic_rmw_op_t::UMin ||
atom->get_op() == ir::atomic_rmw_op_t::Xchg) {
BasicBlock *current = builder_->GetInsertBlock();
Module *module = current->getModule();
Value *rmw_ptr = vals_[atom->get_operand(0)][{}];
Value *rmw_val = vals_[atom->get_operand(1)][{}];
Value *tid = tgt_->get_local_id(module, *builder_, 0);
Value *pred = icmp_eq(tid, i32(0));
BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
tgt_->add_memfence(module, *builder_);
add_barrier();
cond_br(pred, tid_0_bb, tid_0_done_bb);
builder_->SetInsertPoint(tid_0_bb);
AtomicRMWInst::BinOp binop;
switch (atom->get_op()) {
case ir::atomic_rmw_op_t::Add:
binop = AtomicRMWInst::Add;
break;
case ir::atomic_rmw_op_t::Max:
binop = AtomicRMWInst::Max;
break;
case ir::atomic_rmw_op_t::Min:
binop = AtomicRMWInst::Min;
break;
case ir::atomic_rmw_op_t::UMax:
binop = AtomicRMWInst::UMax;
break;
case ir::atomic_rmw_op_t::UMin:
binop = AtomicRMWInst::UMin;
break;
case ir::atomic_rmw_op_t::Xchg:
binop = AtomicRMWInst::Xchg;
break;
default:
throw std::runtime_error("Not supported!");
}
atomic_rmw(binop, rmw_ptr, rmw_val, MaybeAlign(), AtomicOrdering::Monotonic,
SyncScope::System);
br(tid_0_done_bb);
builder_->SetInsertPoint(tid_0_done_bb);
tgt_->add_memfence(module, *builder_);
return;
}
throw std::runtime_error("Not supported!");
}
#else // defined(USE_ROCM)
void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
ir::value *ptr = atom->get_operand(0);
ir::value* val = atom->get_operand(1);
ir::value* msk = atom->get_operand(2);

@@ -1756,6 +1894,7 @@ void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
}
}
}
#endif // defined(USE_ROCM)

/**
* \brief Code Generation for `mma.884` (V100)
@@ -2834,15 +2973,20 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
size_t red_axis = 1;
unsigned NK = A_shapes[red_axis];
bool is_outer = NK == 1;

#ifdef USE_ROCM
return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
#else
bool is_mma = layouts_->get(dot)->to_mma();
if(!is_outer && is_mma && tgt_->as_nvidia()->sm() < 80)
if(!is_outer && is_mma && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80)
return visit_mma884(dot, A, B, D, NK);
if(!is_outer && is_mma && tgt_->as_nvidia()->sm() >= 80)
if(!is_outer && is_mma && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
return visit_mma16816(dot, A, B, D, NK); // rename it as visit_mma_v2()?
if (dot->get_type()->get_scalar_ty()->is_fp32_ty() &&
A->get_type()->get_scalar_ty()->is_fp32_ty())
return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
throw std::runtime_error("dot has invalid operand type");
#endif
}

void generator::visit_trans_inst(ir::trans_inst* trans) {
@@ -2875,8 +3019,14 @@ Value* generator::shared_off(const std::vector<unsigned>& shapes, const std::vec

inline Value* generator::shfl_sync(Value* acc, int32_t i){
Type* ty = acc->getType();
#ifdef USE_ROCM
std::string asm_str = "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;";
InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false);
#else
std::string asm_str = "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;";
InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false);
#endif

if(ty->getPrimitiveSizeInBits() <= 32)
return call(shfl, {acc, i32(i)});
acc = bit_cast(acc, vec_ty(f32_ty, 2));
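Note that both branches of `shfl_sync` carry the same NVIDIA `shfl.sync.bfly.b32` string, so the `USE_ROCM` side is effectively a placeholder: that PTX cannot be assembled for AMDGCN. For orientation, a hedged HIP-level sketch of what a butterfly reduction looks like with `__shfl_xor`, the portable cross-lane primitive (not code from this patch):

```cpp
#include <hip/hip_runtime.h>

// Butterfly (xor) reduction over a wavefront; on most AMD GPUs the
// wavefront is 64 lanes wide, so the mask starts at 32.
__device__ float butterfly_sum(float acc) {
    for (int mask = 32; mask > 0; mask >>= 1)
        acc += __shfl_xor(acc, mask);
    return acc;  // every lane holds the full sum
}
```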

@@ -3171,12 +3321,16 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
default: throw std::runtime_error("unreachable");
}
ir::value *arg = x->get_operand(0);
#ifdef USE_ROCM
visit_reducend_inst(x, do_acc, neutral);
#else
bool is_coalesced_scanline = layouts_->is_coalesced_scanline(x);
bool is_a100_mma = layouts_->is_a100_mma(x);
if (is_coalesced_scanline || is_a100_mma)
visit_reducend_inst_fast(x, do_acc, neutral);
else
visit_reducend_inst(x, do_acc, neutral);
#endif
}

/**
@@ -3645,6 +3799,7 @@ Value *generator::cast_shared_layout_ptr(analysis::data_layout *layout,
}

void generator::visit_function(ir::function* fn) {
std::cout << "generator.cc: generator::visit_function:" << std::endl;
idxs_.clear();
vals_.clear();
seen_.clear();
@@ -3654,6 +3809,7 @@ void generator::visit_function(ir::function* fn) {


// set attributes
std::cout << "\t// set attributes" << std::endl;
for(auto attr_pair: fn->attrs()){
unsigned id = attr_pair.first;
for(ir::attribute attr: attr_pair.second)
@@ -3664,19 +3820,24 @@ void generator::visit_function(ir::function* fn) {
}
}
// set metadata
std::cout << "\t// set metadata" << std::endl;
if(tgt_->is_gpu()){
tgt_->set_kernel(*builder_, ctx, mod_, ret);
#ifndef USE_ROCM
Metadata *md_args[] = {
ValueAsMetadata::get(ret),
MDString::get(ctx, "maxntidx"),
ValueAsMetadata::get(i32(num_warps_*32))
};
mod_->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(ctx, md_args));
#endif
}
// set arguments
std::cout << "\t// set arguments" << std::endl;
for(unsigned i = 0; i < fn->args().size(); i++)
vals_[fn->args()[i]][{}] = &*(ret->arg_begin() + i);
// create blocks
std::cout << "\t// create blocks" << std::endl;
auto blocks = ir::cfg::reverse_post_order(fn);
for(ir::basic_block *block: blocks) {
BasicBlock *dst_block = BasicBlock::Create(ctx, block->get_name(), ret);
@@ -3684,6 +3845,8 @@ void generator::visit_function(ir::function* fn) {
}
builder_->SetInsertPoint(bbs_[fn->blocks()[0]]);
// create policies
#ifndef USE_ROCM
std::cout << "\t// create policies" << std::endl;
if(tgt_->as_nvidia()->sm() >= 80)
for(ir::load_inst::EVICTION_POLICY evict: {ir::load_inst::EVICT_FIRST, ir::load_inst::EVICT_LAST}){
std::string policy = (evict == ir::load_inst::EVICT_FIRST) ? "evict_first" : "evict_last";
@@ -3691,15 +3854,23 @@ void generator::visit_function(ir::function* fn) {
InlineAsm* iasm = InlineAsm::get(FunctionType::get(i64_ty, {}), asm_str, "=l", false);
policies_[evict] = call(iasm);
}
#endif
// initialize layouts
std::cout << "\t// initialize layouts" << std::endl;
for(auto x: layouts_->get_all()){
visit_layout(x.second);
}
// generate LLVM-IR code
std::cout << "\t// generate LLVM-IR code" << std::endl;
for(ir::basic_block *block: blocks)
visit_basic_block(block);
// finalize
std::cout << "\t// finalize" << std::endl;
finalize_function(fn);

// verifyFunction
std::cout << "\t// verifyFunction" << std::endl;
llvm::verifyFunction(*ret);
}


@@ -3723,7 +3894,11 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) {
Value *_8 = i32(8);
Value *_16 = i32(16);
Value *_32 = i32(32);
#ifdef USE_ROCM
int cc = 1; // generate IR as for older CUDA cards
#else
int cc = tgt_->as_nvidia()->sm();
#endif
std::vector<Value*> idx_m;
std::vector<Value*> idx_n;
std::vector<Value*> idx_z;
@@ -4114,6 +4289,7 @@ void generator::packed_type(ir::value* i){
}

void generator::visit(ir::module &src, llvm::Module &dst) {
std::cout << "generator.cc: generator::visit" << std::endl;
mod_ = &dst;
ctx_ = &dst.getContext();
builder_ = new Builder(*ctx_);

@@ -15,10 +15,22 @@ namespace codegen{

// base


nvidia_cu_target* target::as_nvidia() {
return dynamic_cast<nvidia_cu_target*>(this);
#ifdef USE_ROCM
amd_cl_target *target::as_amd()
{
return dynamic_cast<amd_cl_target *>(this);
}
amd_cl_target *target::as_nvidia()
{
return this->as_amd();
}
#else
// causes segfault on ROCM
nvidia_cu_target *target::as_nvidia()
{
return dynamic_cast<nvidia_cu_target *>(this);
}
#endif

bool target::is_gpu() const {
return is_gpu_;
@@ -41,7 +53,8 @@ Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, un
}

Instruction* amd_cl_target::add_memfence(Module *module, IRBuilder<>& builder) {
throw std::runtime_error("not implemented");
Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::amdgcn_s_waitcnt);
return builder.CreateIntrinsic(Intrinsic::amdgcn_s_waitcnt, {}, {builder.getInt32(0)});
}


@@ -56,7 +69,50 @@ Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigne
}

Value* amd_cl_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
throw std::runtime_error("not implemented on AMD");
Function &F = *builder.GetInsertBlock()->getParent();
Module *Mod = F.getParent();
// We are indexing into this struct, and want to extract the grid_size_*
// fields.
//
// typedef struct hsa_kernel_dispatch_packet_s {
//   uint16_t header;
//   uint16_t setup;
//   uint16_t workgroup_size_x;
//   uint16_t workgroup_size_y;
//   uint16_t workgroup_size_z;
//   uint16_t reserved0;
//   uint32_t grid_size_x;
//   uint32_t grid_size_y;
//   uint32_t grid_size_z;
//   .....
// } hsa_kernel_dispatch_packet_t
//
Function *DispatchPtrFn =
Intrinsic::getDeclaration(Mod, Intrinsic::amdgcn_dispatch_ptr);

CallInst *DispatchPtr = builder.CreateCall(DispatchPtrFn, {});
DispatchPtr->addAttribute(AttributeList::ReturnIndex, Attribute::NoAlias);
DispatchPtr->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
F.removeFnAttr("amdgpu-no-dispatch-ptr");

// Size of the dispatch packet struct.
DispatchPtr->addDereferenceableAttr(AttributeList::ReturnIndex, 64);

Type *I32Ty = Type::getInt32Ty(Mod->getContext());
// TODO: include AMDGPUAS:: declarations.
Value *CastDispatchPtr = builder.CreateBitCast(
DispatchPtr, PointerType::get(I32Ty, 4 /*AMDGPUAS::CONSTANT_ADDRESS*/));

// grid_size_x offset is 3*32bit
assert(ax < 3);
Value *GEP =
builder.CreateConstInBoundsGEP1_64(I32Ty, CastDispatchPtr, ax + 3);
LoadInst *Load = builder.CreateAlignedLoad(I32Ty, GEP, Align(4));

MDNode *MD = MDNode::get(Mod->getContext(), None);
Load->setMetadata(LLVMContext::MD_invariant_load, MD);

return Load; // throw std::runtime_error("not implemented on AMD");
}
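The `ax + 3` index only works because of the packet layout spelled out in the comment above: the three 32-bit `grid_size_*` fields sit after 12 bytes of 16-bit fields. A small self-check of those offsets, as a hedged sketch (struct name mirrors the comment, not an included header):

```cpp
#include <cstddef>
#include <cstdint>

struct hsa_kernel_dispatch_packet_s {
    uint16_t header, setup;
    uint16_t workgroup_size_x, workgroup_size_y, workgroup_size_z;
    uint16_t reserved0;
    uint32_t grid_size_x, grid_size_y, grid_size_z;
};
// grid_size_x is the 4th 32-bit word of the packet, hence the GEP index ax + 3.
static_assert(offsetof(hsa_kernel_dispatch_packet_s, grid_size_x) == 12,
              "grid_size_x must sit at byte offset 12 (i32 index 3)");
```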

Value* amd_cl_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
@@ -156,7 +212,7 @@ Value* cpu_target::get_block_id(Module *module, llvm::IRBuilder<> &builder, unsi
}

Value* cpu_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
throw std::runtime_error("not implemented");
throw std::runtime_error("not implemented on CPU");
}


@@ -91,7 +91,7 @@ void inliner::do_inline(ir::function* fn, ir::call_inst* callsite, ir::builder&
if(inst_map.find(inst_op) != inst_map.end())
new_inst->set_operand(k, inst_map.at(inst_op));
}
// handles a ret instruction.
// handles a ret instruciton.
// instead of returning we need to branch to after the function call
if(ir::return_inst* ret = dynamic_cast<ir::return_inst*>(new_inst)) {
if(ir::value* ret_val = ret->get_return_value())

@@ -222,6 +222,7 @@ bool dispatch::hipinit(){
return res;
}

#define HIP_DEFINE0(ret, fname) DEFINE0(hipinit, hip_, ret, fname)
#define HIP_DEFINE1(ret, fname, t1) DEFINE1(hipinit, hip_, ret, fname, t1)
#define HIP_DEFINE2(ret, fname, t1, t2) DEFINE2(hipinit, hip_, ret, fname, t1, t2)
#define HIP_DEFINE3(ret, fname, t1, t2, t3) DEFINE3(hipinit, hip_, ret, fname, t1, t2, t3)
@@ -278,7 +279,8 @@ HIP_DEFINE2(hipError_t, hipEventCreate, hipEvent_t *, unsigned int)
HIP_DEFINE3(hipError_t, hipEventElapsedTime, float *, hipEvent_t, hipEvent_t)
HIP_DEFINE2(hipError_t, hipEventRecord, hipEvent_t, hipStream_t)
HIP_DEFINE1(hipError_t, hipEventDestroy, hipEvent_t)

// error handling
HIP_DEFINE0(hipError_t, hipGetLastError)

/* ------------------- *
* COMMON

@@ -25,6 +25,8 @@
#endif
#include <memory>
#include <regex>
#include <iomanip>
#include <string>
#include "triton/driver/llvm.h"
#include "triton/driver/dispatch.h"
#include "triton/driver/error.h"
@@ -57,6 +59,8 @@
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/Intrinsics.h"
// end AMD stuff

extern "C"
@@ -67,6 +71,24 @@ extern "C"
int setupterm(char *term, int fildes, int *errret) { return 0; }
}

namespace {

std::string gen_random(const int len) {
static const char alphanum[] =
"0123456789"
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz";
std::string tmp_s;
tmp_s.reserve(len);

for (int i = 0; i < len; ++i) {
tmp_s += alphanum[rand() % (sizeof(alphanum) - 1)];
}
return tmp_s;
}

}

namespace triton
{
namespace driver
@@ -266,20 +288,24 @@ namespace triton
/* ------------------------ */
// HIP //
/* ------------------------ */

std::string llir_to_amdgpu(llvm::Module *module, const std::string &_proc)
std::tuple<std::string, std::string> llir_to_amdgcn(llvm::Module *module, const std::string &_proc)
{
std::cout << "llvm.cc: llir_to_amdgcn:" << std::endl;
init_llvm();

// proc = std::get<0>(GetFeatureStrFromGCNArchName(rocminfo));
// features = std::get<1>(GetFeatureStrFromGCNArchName(rocminfo));
// proc = std::get<0>(GetFeatureStrFromGCNArchName(rocminfo));
// features = std::get<1>(GetFeatureStrFromGCNArchName(rocminfo));

// create
llvm::SmallVector<char, 0> buffer;
std::string triple = "amdgcn-amd-amdhsa";
std::string layout = "";
std::string features;
std::string proc = "gfx908";
std::string features = "+sramecc,-xnack";
std::string proc = _proc;
// name kernel
auto in_time_t = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
std::stringstream cur_time;
cur_time << std::put_time(std::localtime(&in_time_t), "%Y-%m-%d--%I-%M-%S");
std::string kernel_name = module->getModuleIdentifier() + "_" + cur_time.str() + "_" + gen_random(12);
// verify and store llvm
llvm::legacy::PassManager pm;
pm.add(llvm::createVerifierPass());
@@ -295,7 +321,7 @@ namespace triton
opt.NoNaNsFPMath = true;
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
llvm::Reloc::PIC_, llvm::None,
llvm::CodeGenOpt::Aggressive);
llvm::CodeGenOpt::None);
// set data layout
if (layout.empty())
module->setDataLayout(machine->createDataLayout());
@@ -308,11 +334,10 @@ namespace triton
llvm::raw_svector_ostream stream(buffer);

// create dump files
std::string module_name = module->getModuleIdentifier();
std::error_code ec;

// Save GCN ISA binary.
std::string isabin_path = std::string("/tmp/") + module_name + std::string(".o");
std::string isabin_path = std::string("/tmp/") + kernel_name + std::string(".o");
std::unique_ptr<llvm::raw_fd_ostream> isabin_fs(
new llvm::raw_fd_ostream(isabin_path, ec, llvm::sys::fs::OF_Text));
if (ec)
@@ -323,15 +348,17 @@ namespace triton
// emit
machine->addPassesToEmitFile(pass, *isabin_fs, nullptr, llvm::CGFT_ObjectFile);
pass.run(*module);

// Save GCN ISA.
std::string amdgcn_path = std::string("/tmp/") + module_name + std::string(".gcn");
std::string result(buffer.begin(), buffer.end());
std::ofstream amdgcn(amdgcn_path);
amdgcn << result;
amdgcn.close();
llvm::SmallVector<char, 0> debugBuffer;
llvm::legacy::PassManager debugPass;
llvm::raw_svector_ostream debugStream(debugBuffer);
machine->addPassesToEmitFile(debugPass, debugStream, nullptr, llvm::CodeGenFileType::CGFT_AssemblyFile); // TODO: causes a segfault on REM ops and also triggers the @llvm.amdgcn.if bug
debugPass.run(*module);
std::string amdgcn(debugBuffer.begin(), debugBuffer.end());

// generate HSACO file
std::string hsaco_path = std::string("/tmp/") + module_name + std::string(".hsaco");
std::string hsaco_path = std::string("/tmp/") + kernel_name + std::string(".hsaco");
std::string error_message;
int lld_result =
llvm::sys::ExecuteAndWait("/opt/rocm/llvm/bin/ld.lld",
@@ -344,13 +371,14 @@ namespace triton
std::cout << lld_result << std::endl;
}

return hsaco_path;
return std::make_tuple(amdgcn, hsaco_path);
}

hipModule_t amdgpu_to_hipmodule(const std::string &path)
hipModule_t amdgpu_to_hipmodule(const std::string &hsaco_path)
{
std::cout << "llvm.cc: amdgpu_to_hipmodule:" << std::endl;
// Read HSACO.
std::ifstream hsaco_file(path, std::ios::binary | std::ios::ate);
std::ifstream hsaco_file(hsaco_path, std::ios::binary | std::ios::ate);
std::ifstream::pos_type hsaco_file_size = hsaco_file.tellg();

std::vector<unsigned char> hsaco(hsaco_file_size);
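Taken together, the reworked entry points compose like this — a hedged caller-side sketch, where `llvm_module` and the `"gfx908"` literal are placeholders: the returned assembly string is for inspection and logging, while the `.hsaco` path is what actually gets loaded into the HIP runtime.

```cpp
// Sketch: LLVM IR -> (AMDGCN asm, hsaco path) -> loaded hipModule_t.
auto [amdgcn, hsaco_path] =
    triton::driver::llir_to_amdgcn(llvm_module.get(), "gfx908");
hipModule_t mod = triton::driver::amdgpu_to_hipmodule(hsaco_path);
```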

@@ -60,6 +60,9 @@ value *builder::get_float16(float val)
value *builder::get_float32(float val)
{ return constant_fp::get(type::get_fp32_ty(ctx_), val); }

value *builder::get_float64(float val)
{ return constant_fp::get(type::get_fp64_ty(ctx_), val); }

value *builder::get_range(int32_t _lo, int32_t _hi) {
constant_int* lo = static_cast<constant_int*>(get_int32(_lo));
constant_int* hi = static_cast<constant_int*>(get_int32(_hi));

@@ -13,6 +13,7 @@ from typing import NamedTuple

from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext
import torch


# Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py
@@ -32,7 +33,8 @@ def get_build_type():
def use_system_llvm():
    if platform.system() == "Windows":
        return True
    versions = ['-11.0', '-11', '-11-64']
    # versions = ['-11.0', '-11', '-11-64']
    versions = ['-13.0', '-13', '-13-64']
    supported = ['llvm-config{v}'.format(v=v) for v in versions]
    paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
    return any(p is not None for p in paths)
@@ -51,9 +53,9 @@ def get_thirdparty_packages(triton_cache_path):
        Package("pybind11", "pybind11-2.10.0", "https://github.com/pybind/pybind11/archive/refs/tags/v2.10.0.tar.gz", "include/pybind11/pybind11.h", "PYBIND11_INCLUDE_DIR", "")
    ]
    if not use_system_llvm():
        # download LLVM if no suitable system LLVM is installed
        # donwload LLVM if no suitable system LLVM is installed
        packages.append(
            Package("llvm", "clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04", "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04.tar.xz", "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR")
            Package("llvm", "clang+llvm-13.0.0-x86_64-linux-gnu-ubuntu-16.04", "https://github.com/llvm/llvm-project/releases/download/llvmorg-13.0.0/clang+llvm-13.0.0-x86_64-linux-gnu-ubuntu-16.04.tar.xz", "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR")
        )

    thirdparty_cmake_args = []
@@ -144,6 +146,9 @@ class CMakeBuild(build_ext):
        build_args += ["--", '-j' + str(2 * multiprocessing.cpu_count())]

        env = os.environ.copy()

        if torch.version.hip is not None:
            env["TRITON_USE_ROCM"] = "ON"
        subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=self.build_temp, env=env)
        subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp)

@@ -232,7 +232,7 @@ ir::value *store(ir::value *ptr, ir::value *val, std::optional<ir::value *> _mas
----------------------------------------------*/
std::string dot_docstr = R"pbdoc(
Returns the matrix product of two blocks.
The two blocks must be two dimensions and have compatible inner dimensions.
The two blocks must be two-dimensional and have compatible inner dimensions.

:param input: The first block to be multiplied.
:type input: 2D block of scalar-type in {`float16`, `float32`}

@@ -100,7 +100,9 @@ void hip_enqueue(uint64_t stream, uint64_t kernel,
drv::dispatch::hipModuleLaunchKernel((hipFunction_t)kernel, grid_0, grid_1, grid_2,
block_0, block_1, block_2,
shared_mem, (hipStream_t)stream, nullptr, config);

#ifdef DEBUG_ROCM
drv::dispatch::hipGetLastError();
#endif
}

long pow2_divisor(long N){
@@ -435,7 +437,7 @@ typedef std::map<std::string, py::object> asm_map_t;
// ---------------------------------------

void init_triton_codegen(py::module &&m) {
m.def("compile_ttir",
m.def("compile_ttir_to_ptx",
[](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs, size_t cc) {
std::ostringstream ttir;
int n_shared_bytes;
@@ -490,11 +492,64 @@ void init_triton_codegen(py::module &&m) {
},
py::return_value_policy::take_ownership);

m.def("compile_ttir_to_amdgcn",
[](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs, const std::string& gfx_arch) {
std::ostringstream ttir;
int n_shared_bytes;
std::string tmp;
std::string amdgcn;
std::string hsaco_path;
std::string name;
{
std::cout << "triton.cc: compile_ttir_to_amdgcn:" << std::endl;
// Scope where the GIL is released
py::gil_scoped_release allow_threads;
name = ir.get_function_list()[0]->get_name();
ir.print(ttir);
llvm::LLVMContext ctx;
// construct extern lib map
triton::codegen::ExternLibMap extern_lib_map;
for (auto item : extern_libs) {
auto name = item.first.cast<std::string>();
auto path = item.second.cast<std::string>();
extern_lib_map.emplace(
name, triton::codegen::create_extern_lib(name, path));
}
int version;
// std::string ptxas_path = drv::path_to_ptxas(version);
// Triton-IR -> AMDGCN LLVM-IR
std::cout << "ttir:" << std::endl;
std::cout << "\t" << ttir.str() << std::endl;
triton::codegen::amd_cl_target target;
auto llvm = triton::codegen::add_passes_to_emit_bin(
ir, ctx, &target, num_warps, num_stages, n_shared_bytes, extern_lib_map);
llvm::raw_string_ostream llir(tmp);
llir << *llvm;
std::cout << "llir:" << std::endl;
std::cout << "\t" << llir.str() << std::endl;
llir.flush();
// LLVM-IR -> AMDGPU
std::tuple<std::string, std::string> amdgpu = drv::llir_to_amdgcn(llvm.get(), gfx_arch);
amdgcn = std::get<0>(amdgpu);
hsaco_path = std::get<1>(amdgpu);
std::cout << "amdgcn:" << std::endl;
std::cout << "\t" << amdgcn << std::endl;
}
asm_map_t asm_map;
asm_map["ttir"] = py::cast(ttir.str());
asm_map["llir"] = py::cast(tmp);
asm_map["amdgcn"] = py::cast(amdgcn);
asm_map["hsaco_path"] = py::cast(hsaco_path);

return std::make_tuple(name, asm_map, n_shared_bytes);
},
py::return_value_policy::take_ownership);


// ---------------------------------------
// Load provided assembly code into driver
// ---------------------------------------
m.def("load_binary", [](const std::string& name, const std::string& data, size_t n_shared_bytes, uint64_t device){
m.def("load_binary_cubin", [](const std::string& name, const std::string& data, size_t n_shared_bytes, uint64_t device){
py::gil_scoped_release allow_threads;
// create driver handles
CUfunction fun;
@@ -521,6 +576,47 @@ void init_triton_codegen(py::module &&m) {
},
py::return_value_policy::take_ownership
);

// ---------------------------------------
// Load provided assembly code into driver
// ---------------------------------------
m.def("load_binary_hsaco", [](const std::string& name, const std::string& data, size_t n_shared_bytes, uint64_t device){
std::cout << "triton.cc: load_binary_hsaco:" << std::endl;
std::cout << "\tname:" << name << std::endl;
std::cout << "\tdata:" << data << std::endl;
std::cout << "\tn_shared_bytes:" << n_shared_bytes << std::endl;
std::cout << "\tdevice:" << device << std::endl;
py::gil_scoped_release allow_threads;
// create driver handles
std::cout << "\t" << "// create driver handles" << std::endl;
hipFunction_t fun;
hipModule_t mod = drv::amdgpu_to_hipmodule(data);
drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str());
// get allocated registers and spilled registers from the function
std::cout << "\t" << "// get allocated registers and spilled registers from the function" << std::endl;
int n_regs = 0;
int n_spills = 0;
hipFuncAttributes attr;
// drv::dispatch::hipFuncGetAttributes(&attr, fun);
// drv::dispatch::hipFuncGetAttributes(&attr, fun);
n_regs = attr.numRegs;
n_spills = attr.localSizeBytes / 4;
// set dynamic shared memory if necessary
std::cout << "\t" << "// set dynamic shared memory if necessary" << std::endl;
int shared_optin;
// drv::dispatch::hipDeviceGetAttribute(&shared_optin, hipDeviceAttributeSharedMemPerBlockOptin, device);
if(n_shared_bytes > 49152 && shared_optin > 49152){
// drv::dispatch::hipFuncSetCacheConfig(fun, hipFuncCachePreferShared);
int shared_total, shared_static;
// drv::dispatch::hipDeviceGetAttribute(&shared_total, hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, device);
// drv::dispatch::hipFuncGetAttributes(&attr, fun);
shared_total = attr.sharedSizeBytes;
// drv::dispatch::hipFuncSetAttribute(fun, hipFuncAttributeMaxDynamicSharedMemorySize, shared_optin - shared_static);
}
return std::make_tuple((uint64_t)mod, (uint64_t)fun, (uint64_t)n_regs, (uint64_t)n_spills);
},
py::return_value_policy::take_ownership
);


struct InstanceDescriptor
|
||||
|
@@ -128,7 +128,7 @@ elementwise_data = {
|
||||
1024 * 16: 0.0219,
|
||||
1024 * 64: 0.0791,
|
||||
1024 * 256: 0.243,
|
||||
1024 * 1024: 0.530,
|
||||
1024 * 1024: 0.534,
|
||||
1024 * 4096: 0.796,
|
||||
1024 * 16384: 0.905,
|
||||
1024 * 65536: 0.939,
|
||||
|
@@ -104,9 +104,12 @@ def check_type_supported(dtype):
|
||||
'''
|
||||
skip test if dtype is not supported on the current device
|
||||
'''
|
||||
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
|
||||
if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
|
||||
pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
|
||||
if torch.version.hip is not None:
|
||||
pass
|
||||
else:
|
||||
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
|
||||
if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
|
||||
pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes] + ["bfloat16"])
|
||||
@@ -123,6 +126,9 @@ def test_empty_kernel(dtype_x, device='cuda'):
|
||||
|
||||
# generic test functions
|
||||
def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
|
||||
if torch.version.hip is not None:
|
||||
if dtype_x == "bfloat16":
|
||||
pytest.skip("unary op with bfloat is not supported on AMDGPU")
|
||||
check_type_supported(dtype_x) # early return if dtype_x is not supported
|
||||
SIZE = 128
|
||||
# define the kernel / launch-grid
|
||||
@@ -230,7 +236,6 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
|
||||
('int64', 'float32'),
|
||||
('int64', 'float64'),
|
||||
('uint16', 'bfloat16'),
|
||||
('uint16', 'float16'),
|
||||
('uint16', 'float32'),
|
||||
('uint32', 'bfloat16'),
|
||||
('uint32', 'float16'),
|
||||
@@ -253,8 +258,12 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
    for dtype_y in dtypes_with_bfloat16
])
def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
    if torch.version.hip is not None:
        if dtype_x == "bfloat16" and dtype_y == "bfloat16":
            pytest.skip("binary op with bfloat16 is not supported on AMDGPU")

    expr = f' x {op} y'
    if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes:
    if op == '%' and (dtype_x in dtypes and dtype_y in dtypes):
        # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders.
        numpy_expr = 'np.fmod(x, y)'
    elif op in ('/', '%') and dtype_x in ('int16', 'float16', 'bfloat16') and dtype_y in ('int16', 'float16', 'bfloat16'):
@@ -605,10 +614,10 @@ def test_tuples():
    ]
    for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
    if cc < 70:
        if dtype_x_str == 'float16':
            pytest.skip("Only test atomic float16 ops on devices with sm >= 70")
    if torch.version.hip is not None:
        # if dtype_x_str in ["uint32", "int32", "float32"]:
        pytest.skip(f"test_atomic_rmw[{dtype_x_str}] currently has segfaults on ROCM")

    n_programs = 5

    # triton kernel
@@ -675,6 +684,8 @@ def test_tensor_atomic_rmw(axis, device="cuda"):


def test_atomic_cas():
    if torch.version.hip is not None:
        pytest.skip("test_atomic_cas currently has segfaults on ROCM")
    # 1. make sure that atomic_cas changes the original value (Lock)
    @triton.jit
    def change_value(Lock):
@@ -782,6 +793,8 @@ def test_store_bool():

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_f8_xf16_roundtrip(dtype):
    if torch.version.hip is not None:
        pytest.skip("test_f8_xf16_roundtrip currently has segfaults on ROCM")
    """Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
    check_type_supported(dtype)

@@ -808,6 +821,8 @@ def test_f8_xf16_roundtrip(dtype):


def test_f16_to_f8_rounding():
    if torch.version.hip is not None:
        pytest.skip("test_f16_to_f8_rounding currently has segfaults on ROCM")
    """Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
    error is the minimum over all float8.
    Or the same explanation a bit mathier:
@@ -876,6 +891,9 @@ def test_f16_to_f8_rounding():
    for shape in [32, 64, 128, 512]])
def test_reduce1d(op, dtype_str, shape, device='cuda'):
    check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested
    if torch.version.hip is not None:
        if dtype_str in ["int8", "int16", "uint8", "uint16"]:
            pytest.skip(f"test_reduce1d[{dtype_str}] skipped on ROCM")

    # triton kernel
    @triton.jit
@@ -934,6 +952,10 @@ reduce_configs2 = [

@pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2)
def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
    check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested
    if torch.version.hip is not None:
        if dtype_str in ["int8", "int16", "uint8", "uint16"]:
            pytest.skip(f"test_reduce2d[{dtype_str}] skipped on ROCM")
    # triton kernel
    @triton.jit
    def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
@@ -1025,13 +1047,18 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
    # compare
    triton.testing.assert_almost_equal(z_tri, z_ref)
    triton.testing.assert_almost_equal(z_tri_contiguous, z_ref)
    # parse ptx to make sure ld/st are vectorized
    ptx = pgm.asm['ptx']
    assert 'ld.global.v4' in ptx
    assert 'st.global.v4' in ptx
    ptx = pgm_contiguous.asm['ptx']
    assert 'ld.global.v4' in ptx
    assert 'st.global.v4' in ptx

    if torch.version.hip is None:
        # parse ptx to make sure ld/st are vectorized
        ptx = pgm.asm['ptx']
        assert 'ld.global.v4' in ptx
        assert 'st.global.v4' in ptx
        ptx = pgm_contiguous.asm['ptx']
        assert 'ld.global.v4' in ptx
        assert 'st.global.v4' in ptx
    else:
        # TODO: add ROCm GCN assert
        pass

# ---------------
# test dot
@@ -1045,14 +1072,15 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
    for dtype in ['float16']
    if not (allow_tf32 and (dtype in ['float16']))])
def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
    if cc < 70:
        pytest.skip("Only test tl.dot() on devices with sm >= 70")
    if cc < 80:
        if dtype == 'int8':
            pytest.skip("Only test int8 on devices with sm >= 80")
        elif dtype == 'float32' and allow_tf32:
            pytest.skip("Only test tf32 on devices with sm >= 80")
    if torch.version.hip is not None:
        pass
    else:
        cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
        if cc < 80:
            if dtype == 'int8':
                pytest.skip("Only test int8 on devices with sm >= 80")
            elif dtype == 'float32' and allow_tf32:
                pytest.skip("Only test tf32 on devices with sm >= 80")

    M, N, K = 128, 128, 64
    num_warps = 8
@@ -1147,15 +1175,18 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
    # print(z_ref[:,0], z_tri[:,0])
    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
    # make sure ld/st are vectorized
    ptx = pgm.asm['ptx']
    assert 'ld.global.v4' in ptx
    assert 'st.global.v4' in ptx
    if allow_tf32:
        assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
    elif dtype == 'float32':
        assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
    elif dtype == 'int8':
        assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
    if torch.version.hip is not None:
        pass
    else:
        ptx = pgm.asm['ptx']
        assert 'ld.global.v4' in ptx
        assert 'st.global.v4' in ptx
        if allow_tf32:
            assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
        elif dtype == 'float32':
            assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
        elif dtype == 'int8':
            assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx


def test_dot_without_load():
@@ -1233,10 +1264,8 @@ def test_masked_load(dtype_str, size, size_diff, device='cuda'):

@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_masked_load_shared_memory(dtype, device='cuda'):
    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
    if cc < 70:
        pytest.skip("Only test tl.dot() on devices with sm >= 70")

    if torch.version.hip is not None:
        pytest.skip("test_masked_load_shared_memory currently has segfaults on ROCM")
    check_type_supported(dtype)  # bfloat16 on cc < 80 will not be tested

    M = 32
@@ -1296,16 +1325,20 @@ def test_load_cache_modifier(cache):
        tl.store(dst + offsets, x)

    pgm = _kernel[(1,)](dst, src, CACHE=cache)
    ptx = pgm.asm['ptx']
    if cache == '':
        assert 'ld.global.ca' not in ptx
        assert 'ld.global.cg' not in ptx
    if cache == '.cg':
        assert 'ld.global.cg' in ptx
        assert 'ld.global.ca' not in ptx
    if cache == '.ca':
        assert 'ld.global.ca' in ptx
        assert 'ld.global.cg' not in ptx
    if torch.version.hip is None:
        ptx = pgm.asm['ptx']
        if cache == '':
            assert 'ld.global.ca' not in ptx
            assert 'ld.global.cg' not in ptx
        if cache == '.cg':
            assert 'ld.global.cg' in ptx
            assert 'ld.global.ca' not in ptx
        if cache == '.ca':
            assert 'ld.global.ca' in ptx
            assert 'ld.global.cg' not in ptx
    else:
        # TODO: add ROCm GCN assert
        pass

@pytest.mark.parametrize("N", [16, 10, 11, 1024])
@@ -1319,11 +1352,15 @@ def test_vectorization(N):
        x = tl.load(src + offsets, mask=offsets < N)
        tl.store(dst + offsets, x, mask=offsets < N)
    pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
    ptx = pgm.asm["ptx"]
    if N % 16 == 0:
        assert "ld.global.v4.b32" in ptx
    else:
        assert "ld.global.b32" in ptx
    if torch.version.hip is None:
        ptx = pgm.asm["ptx"]
        if N % 16 == 0:
            assert "ld.global.v4.b32" in ptx
        else:
            assert "ld.global.b32" in ptx
    else:
        # TODO: add ROCm assert
        pass
    # triton.testing.assert_almost_equal(dst, src[:N])
# ---------------
# test store
@@ -1557,7 +1594,8 @@ def test_num_warps_pow2():
    ('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
    ('float64', 'libdevice.norm4d', '')])
def test_libdevice_tensor(dtype_str, expr, lib_path):

    if torch.version.hip is not None:
        pytest.skip("test_libdevice_tensor currently has segfaults on ROCM")
    @triton.jit
    def kernel(X, Y, BLOCK: tl.constexpr):
        x = tl.load(X + tl.arange(0, BLOCK))
@@ -1597,6 +1635,8 @@ def test_libdevice_tensor(dtype_str, expr, lib_path):
@pytest.mark.parametrize("dtype_str, expr, lib_path",
                         [('float32', 'libdevice.pow', '')])
def test_libdevice_scalar(dtype_str, expr, lib_path):
    if torch.version.hip is not None:
        pytest.skip("test_libdevice_scalar currently has segfaults on ROCM")

    @triton.jit
    def kernel(X, Y, BLOCK: tl.constexpr):

@@ -2,7 +2,6 @@ import pytest
import torch

import triton
import triton._C.libtriton.triton as _triton


@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
@@ -126,10 +125,6 @@ def test_attention_fwd_bwd(
    batch_size=2,
    n_heads=2,
):
    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
    if cc < 70:
        pytest.skip("Only test tl.dot() on devices with sm >= 70")

    # inputs
    qkv_shape = (batch_size, n_heads, n_ctx, 64)
    qkvs = [

@@ -68,8 +68,6 @@ import triton._C.libtriton.triton as _triton
)
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
    if cc < 70:
        pytest.skip("Only test tl.dot() on devices with sm >= 70")
    if cc < 80 and DTYPE == "bfloat16":
        pytest.skip("Only test bfloat16 on devices with sm >= 80")
    if DTYPE == "bfloat16" and SPLIT_K != 1:

@@ -7,6 +7,7 @@ import hashlib
import io
import json
import os
import re
import shutil
import subprocess
import sys
@@ -24,6 +25,12 @@ import triton
import triton._C.libtriton.triton as _triton
from .tools.disasm import extract


def static_vars(**kwargs):
    def decorate(func):
        for k in kwargs:
            setattr(func, k, kwargs[k])
        return func
    return decorate

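`static_vars` simply attaches keyword attributes to the decorated function; the branch uses it below to memoize the discovered GPU architecture. A small usage sketch (the counter is illustrative only, not from this diff):

@static_vars(counter=0)
def tick():
    # the attribute behaves like a C static local
    tick.counter += 1
    return tick.counter

tick()  # -> 1
tick()  # -> 2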
def str_to_ty(name):
    if name[0] == "*":
@@ -880,31 +887,55 @@ def ptx_get_kernel_name(ptx: str) -> str:
        if line.startswith('// .globl'):
            return line.split()[-1]

@functools.lru_cache()
def rocm_path_dir():
    return os.getenv("ROCM_PATH", default="/opt/rocm")

def _get_amdgpu_arch():
    try:
        rocminfo = subprocess.check_output(rocm_path_dir() + '/bin/rocminfo').decode()
        gfx_arch = re.search('Name:\\s+.*(gfx\\d+)', rocminfo)
        return gfx_arch.group(1).strip()
    except Exception:
        return None

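`_get_amdgpu_arch` shells out to `rocminfo` and scrapes the first `gfx*` device name; returning `None` on any failure lets callers fall back to the `MI_GPU_ARCH` environment variable, exactly as `_compile` does below. A usage sketch (the printed value depends on the installed GPU):

arch = os.environ.get('MI_GPU_ARCH', _get_amdgpu_arch())
if arch is None:
    raise RuntimeError('AMDGCN gfx arch is not defined.')
print(arch)  # e.g. "gfx90a" on an MI200-class device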
@static_vars(discovered_gfx_arch=_get_amdgpu_arch())
def _compile(fn, signature: str, device: int = -1, constants=dict(),
             specialization=_triton.code_gen.instance_descriptor(),
             num_warps: int = 4, num_stages: int = 3, extern_libs=None,
             output: str = "ttgir", cc=0) -> Tuple[str, int, str]:
    print("compiler.py: _compile")
    print(f"\t{fn, signature, device, constants, specialization, num_warps, num_stages, extern_libs, output, cc}")
    valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
    assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
    # assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)

    # triton-ir
    module, _ = make_triton_ir(fn, signature, specialization, constants)
    if output == "ttir":
        return module

    assert output == "cubin"
    assert torch.version.hip is None
    backend = _triton.runtime.backend.CUDA
    assert (output == "cubin" or output == "hsaco")
    if torch.version.hip is not None:
        backend = _triton.runtime.backend.ROCM
    else:
        backend = _triton.runtime.backend.CUDA
    if extern_libs is None:
        extern_libs = dict()
    name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, module, device, num_warps, num_stages, extern_libs, cc)

    # compile ttir
    if torch.version.hip is not None:
        gfx_arch = os.environ.get('MI_GPU_ARCH', _compile.discovered_gfx_arch)
        if gfx_arch is None:
            raise RuntimeError('AMDGCN gfx arch is not defined.')
        name, asm, shared_mem = _triton.code_gen.compile_ttir_to_amdgcn(backend, module, device, num_warps, num_stages, extern_libs, gfx_arch)
    else:
        name, asm, shared_mem = _triton.code_gen.compile_ttir_to_ptx(backend, module, device, num_warps, num_stages, extern_libs, cc)
    return asm, shared_mem, name


def ty_to_cpp(ty):
    if ty[0] == '*':
        return "CUdeviceptr"
        return "hipDeviceptr_t"
    return {
        "i1": "int32_t",
        "i8": "int8_t",
@@ -967,17 +998,18 @@ def generate_launcher(identifier, constants, signature):
    format = "iiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])

    # generate glue code
    src = f"""
#include \"cuda.h\"
    if torch.version.hip is not None:
        src = f"""
#define __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#include <Python.h>

static inline void gpuAssert(CUresult code, const char *file, int line)
static inline void gpuAssert(hipError_t code, const char *file, int line)
{{
  if (code != CUDA_SUCCESS)
  if (code != hipSuccess)
  {{
     const char* prefix = "Triton Error [CUDA]: ";
     const char* str;
     cuGetErrorString(code, &str);
     const char* str = hipGetErrorString(code);
     char err[1024] = {{0}};
     strcat(err, prefix);
     strcat(err, str);
@@ -987,20 +1019,20 @@ static inline void gpuAssert(CUresult code, const char *file, int line)
#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}


void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, CUstream stream, CUfunction function, {arg_decls}) {{
void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, hipStream_t stream, hipFunction_t function, {arg_decls}) {{
  void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
  if(gridX*gridY*gridZ > 0){{
    CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
    hipModuleLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0);
  }}
}}


static inline CUdeviceptr getPointer(PyObject *obj, int idx) {{
static inline hipDeviceptr_t getPointer(PyObject *obj, int idx) {{
  if (PyLong_Check(obj)) {{
    return (CUdeviceptr)PyLong_AsUnsignedLongLong(obj);
    return (hipDeviceptr_t)PyLong_AsUnsignedLongLong(obj);
  }}
  if (obj == Py_None) {{
    return (CUdeviceptr)0;
    return (hipDeviceptr_t)0;
  }}
  PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
  if(ptr){{
@@ -1011,14 +1043,15 @@ static inline CUdeviceptr getPointer(PyObject *obj, int idx) {{
    if (!PyLong_Check(ret)) {{
      PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
    }}
    return (CUdeviceptr)PyLong_AsUnsignedLongLong(ret);
    return (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
  }}
  PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
  return (CUdeviceptr)0;
  return (hipDeviceptr_t)0;
}}


static PyObject* launch(PyObject* self, PyObject* args) {{
  // printf("launch(PyObject* self, PyObject* args)");
  int gridX, gridY, gridZ;
  uint64_t _stream;
  uint64_t _function;
@@ -1039,7 +1072,7 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
    Py_DECREF(new_args);
  }}

  _launch(gridX, gridY, gridZ, num_warps, shared_memory, (CUstream)_stream, (CUfunction)_function, {', '.join(f"getPointer(_arg{i},{i})" if ty[0] == "*" else f"_arg{i}" for i, ty in signature.items())});
  _launch(gridX, gridY, gridZ, num_warps, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, {', '.join(f"getPointer(_arg{i},{i})" if ty[0] == "*" else f"_arg{i}" for i, ty in signature.items())});

  if (launch_exit_hook != Py_None) {{
    PyObject *new_args = NULL;
@@ -1126,7 +1159,7 @@ class CacheManager:
            os.rename(filepath + ".tmp", filepath)


# utilities for generating and compiling C wrappers


@functools.lru_cache()
@@ -1134,6 +1167,10 @@ def libcuda_dirs():
    locs = subprocess.check_output(["whereis", "libcuda.so"]).decode().strip().split()[1:]
    return [os.path.dirname(loc) for loc in locs]

@functools.lru_cache()
def libhip_dirs():
    return ["/opt/rocm/lib"]


@functools.lru_cache()
def cuda_home_dirs():
@@ -1141,6 +1178,15 @@ def cuda_home_dirs():
    return os.getenv("CUDA_HOME", default=default_dir)


@functools.lru_cache()
def hip_home_dirs():
    default_dir = "/opt/rocm"
    return os.getenv("ROCM_HOME", default=default_dir)

@functools.lru_cache()
def rocm_path_dir():
    return os.getenv("ROCM_PATH", default="/opt/rocm")

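The ROCm location helpers mirror their CUDA counterparts and honor the usual environment overrides; a quick sketch of the resolution order they encode (default paths assume a standard ROCm install, and results are lru_cached, so set the variables before the first call):

import os

os.environ.setdefault("ROCM_HOME", "/opt/rocm")   # hip_home_dirs() honors this
os.environ.setdefault("ROCM_PATH", "/opt/rocm")   # rocm_path_dir() honors this
print(hip_home_dirs())   # "/opt/rocm" unless ROCM_HOME points elsewhere
print(libhip_dirs())     # ["/opt/rocm/lib"], currently hard-coded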
@contextlib.contextmanager
def quiet():
    old_stdout, old_stderr = sys.stdout, sys.stderr
@@ -1152,8 +1198,15 @@ def quiet():


def _build(name, src, srcdir):
    cuda_lib_dirs = libcuda_dirs()
    cu_include_dir = os.path.join(cuda_home_dirs(), "include")
    print("compiler.py: _build")
    print(f"\t{name, src, srcdir}")
    if torch.version.hip is not None:
        hip_lib_dirs = libhip_dirs()
        hip_include_dir = os.path.join(hip_home_dirs(), "include")
    else:
        cuda_lib_dirs = libcuda_dirs()
        cu_include_dir = os.path.join(cuda_home_dirs(), "include")

    suffix = sysconfig.get_config_var('EXT_SUFFIX')
    so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
    # try to avoid setuptools if possible
@@ -1164,16 +1217,29 @@ def _build(name, src, srcdir):
    gcc = shutil.which("gcc")
    cc = gcc if gcc is not None else clang
    py_include_dir = get_paths()["include"]
    cc_cmd = [cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-o", so]
    cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs]
    if torch.version.hip is not None:
        cc_cmd = [cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lamdhip64", "-o", so]
        cc_cmd += [f"-L{dir}" for dir in hip_lib_dirs]
    else:
        cc_cmd = [cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-o", so]
        cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs]
    print("\t", ' '.join(cc_cmd))
    ret = subprocess.check_call(cc_cmd)
    if ret == 0:
        print("ret:", ret)
        print(so)
        return so
    # fallback on setuptools
    extra_compile_args = []
    library_dirs = cuda_lib_dirs
    include_dirs = [srcdir, cu_include_dir]
    libraries = ['cuda']
    if torch.version.hip is not None:
        library_dirs = hip_lib_dirs
        include_dirs = [srcdir, hip_include_dir]
        libraries = ['amdhip64']
    else:
        library_dirs = cuda_lib_dirs
        include_dirs = [srcdir, cu_include_dir]
        libraries = ['cuda']

    # extra arguments
    extra_link_args = []
    # create extension module
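On the direct-gcc path, the HIP branch links against `libamdhip64` instead of `libcuda` and drops the `-O3` flag; the resulting command looks roughly like this (all paths illustrative and machine-dependent):

# what cc_cmd evaluates to for a HIP build of a launcher named "kernel"
cc_cmd = ["gcc", "/tmp/tmpdir/main.c",
          "-I/opt/rocm/include", "-I/usr/include/python3.8", "-I/tmp/tmpdir",
          "-shared", "-fPIC", "-lamdhip64",
          "-o", "/tmp/tmpdir/kernel.cpython-38-x86_64-linux-gnu.so",
          "-L/opt/rocm/lib"]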
@@ -1221,6 +1287,8 @@ def make_fn_cache_key(fn_hash, signature, configs, constants, num_warps, num_sta

def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4,
            num_stages: int = 3, extern_libs=None, configs=None, cc=0, warm_cache_only=False):
    print("compiler.py: compile")
    print(f"\t{fn, signature, device, constants, num_warps, num_stages, extern_libs, configs, cc, warm_cache_only}")
    # we get the kernel, i.e. the first function generated in the module
    assert len(configs) == 1
    # cache manager
@@ -1236,30 +1304,52 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
            src_path = os.path.join(tmpdir, "main.c")
            with open(src_path, "w") as f:
                f.write(src)
            so = _build(fn.__name__, src_path, tmpdir)
            so = _build(fn.__name__, src_path, tmpdir)  # build step
            with open(so, "rb") as f:
                so_cache_manager.put(f.read(), so_name, binary=True)

    # retrieve cached shared object if it exists
    fn_cache_key = make_fn_cache_key(fn.cache_key, signature, configs, constants, num_warps, num_stages)
    fn_cache_manager = CacheManager(fn_cache_key)
    ptx_name = f"{name}.ptx"
    cubin_name = f"{name}.cubin"
    if torch.version.hip is not None:
        amdgcn_name = f"{name}.gcn"
        hsaco_name = f"{name}.hsaco"
        assembly_name = amdgcn_name
        binary_name = hsaco_name
    else:
        ptx_name = f"{name}.ptx"
        cubin_name = f"{name}.cubin"
        assembly_name = ptx_name
        binary_name = cubin_name

    data_name = f"{name}.json"
    ttir_name = f"{name}.ttir"
    llir_name = f"{name}.llir"
    if not fn_cache_manager.has_file(cubin_name) or \
    if not fn_cache_manager.has_file(binary_name) or \
            not fn_cache_manager.has_file(data_name) or \
            not fn_cache_manager.has_file(ptx_name) or \
            not fn_cache_manager.has_file(assembly_name) or \
            not fn_cache_manager.has_file(ttir_name) or \
            not fn_cache_manager.has_file(llir_name):
        asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
                                            extern_libs, "cubin", cc)
        metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
        fn_cache_manager.put(asm["cubin"], cubin_name)
        fn_cache_manager.put(asm["ptx"], ptx_name, binary=False)

        if torch.version.hip is not None:
            asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
                                                extern_libs, "hsaco", cc)
            # cache AMD assembly and binary
            fn_cache_manager.put(asm["hsaco_path"], binary_name, binary=False)
            fn_cache_manager.put(asm["amdgcn"], assembly_name, binary=False)
        else:
            asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
                                                extern_libs, "cubin", cc)
            # cache Nvidia assembly and binary
            fn_cache_manager.put(asm["cubin"], binary_name)
            fn_cache_manager.put(asm["ptx"], assembly_name, binary=False)

        # cache triton and llvm ir
        fn_cache_manager.put(asm["ttir"], ttir_name, binary=False)
        fn_cache_manager.put(asm["llir"], llir_name, binary=False)

        # cache metadata
        metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
        fn_cache_manager.put(json.dumps(metadata), data_name, binary=False)

    if warm_cache_only:
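The cache layout that falls out of the naming above, for a kernel called `add_kernel` (illustrative):

# CUDA backend                   # ROCm backend
# add_kernel.cubin  (binary)     # add_kernel.hsaco  (binary)
# add_kernel.ptx    (assembly)   # add_kernel.gcn    (assembly)
# add_kernel.ttir, add_kernel.llir and add_kernel.json are written for both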
@@ -1275,9 +1365,12 @@ class CompiledKernel:
    launch_exit_hook = None

    def __init__(self, fn_name, so_path, cache_dir, device):
        print("compiler.py: CompiledKernel:__init__")
        print(f"\t{self, fn_name, so_path, cache_dir, device}")
        # initialize launcher
        import importlib.util
        spec = importlib.util.spec_from_file_location("launcher", so_path)
        print("spec:", spec)
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
        self.c_wrapper = getattr(mod, "launch")
@@ -1289,16 +1382,25 @@ class CompiledKernel:
        self.num_stages = metadata["num_stages"]
        # initialize asm dict
        self.asm = dict()
        with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
            self.asm["cubin"] = f.read()
        with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
            self.asm["ptx"] = f.read()
        if torch.version.hip is not None:
            with open(os.path.join(cache_dir, f"{fn_name}.hsaco"), "rb") as f:
                self.asm["hsaco_path"] = f.read()
            with open(os.path.join(cache_dir, f"{fn_name}.gcn"), "r") as f:
                self.asm["amdgcn"] = f.read()
        else:
            with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
                self.asm["cubin"] = f.read()
            with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
                self.asm["ptx"] = f.read()
        with open(os.path.join(cache_dir, f"{fn_name}.llir"), "r") as f:
            self.asm["llir"] = f.read()
        with open(os.path.join(cache_dir, f"{fn_name}.ttir"), "r") as f:
            self.asm["ttir"] = f.read()

        mod, func, n_regs, n_spills = _triton.code_gen.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
        if torch.version.hip is not None:
            mod, func, n_regs, n_spills = _triton.code_gen.load_binary_hsaco(metadata["name"], self.asm["hsaco_path"], self.shared, device)
        else:
            mod, func, n_regs, n_spills = _triton.code_gen.load_binary_cubin(metadata["name"], self.asm["cubin"], self.shared, device)
        self.fn_name = fn_name
        self.cu_module = mod
        self.cu_function = func

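Because the asm dict is populated per backend, callers have to ask for the right key, which is exactly why the PTX assertions in the tests earlier are wrapped in `torch.version.hip is None` guards. A sketch (the kernel handle `pgm` is hypothetical):

if torch.version.hip is None:
    print(pgm.asm["ptx"][:200])      # NVIDIA: PTX text
else:
    print(pgm.asm["amdgcn"][:200])   # ROCm: AMDGCN assembly
print(pgm.asm["ttir"][:200])         # Triton IR is available on both backends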
@@ -768,7 +768,7 @@ def dot(input, other, trans_a=False, trans_b=False, allow_tf32=True, _builder=No
    """
    Returns the matrix product of two blocks.

    The two blocks must be two-dimensional and have compatible inner dimensions.

    :param input: The first tensor to be multiplied.
    :type input: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}

@@ -58,7 +58,13 @@ def mulhi(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.int32, core.int32,): ("__nv_mulhi", core.int32),
                               (core.uint32, core.uint32,): ("__nv_umulhi", core.uint32),
                               (core.int64, core.int64,): ("__nv_mul64hi", core.int64),
                               }, _builder)


@extern.extern
def mul64hi(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.int64, core.int64,): ("__nv_mul64hi", core.int64),
                               (core.uint64, core.uint64,): ("__nv_umul64hi", core.uint64),
                               }, _builder)

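Each wrapper maps a tuple of input dtypes to a concrete `__nv_*` libdevice symbol; `extern.elementwise` selects the entry whose key matches the incoming arguments. A usage sketch inside a JIT kernel (buffer layout hypothetical; module path as in this branch):

import triton
import triton.language as tl
from triton.language import libdevice

@triton.jit
def mulhi_kernel(X, Y, Z, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(X + offs)
    y = tl.load(Y + offs)
    # int32 inputs dispatch to __nv_mulhi; uint32 would select __nv_umulhi
    tl.store(Z + offs, libdevice.mulhi(x, y))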
@@ -152,137 +158,261 @@ def saturatef(arg0, _builder=None):


@extern.extern
def fma_rn(arg0, arg1, arg2, _builder=None):
def fmaf_rn(arg0, arg1, arg2, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
                              {(core.float32, core.float32, core.float32,): ("__nv_fmaf_rn", core.float32),
                               (core.float64, core.float64, core.float64,): ("__nv_fma_rn", core.float64),
                               }, _builder)


@extern.extern
def fmaf_rz(arg0, arg1, arg2, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
                              {(core.float32, core.float32, core.float32,): ("__nv_fmaf_rz", core.float32),
                               }, _builder)


@extern.extern
def fmaf_rd(arg0, arg1, arg2, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
                              {(core.float32, core.float32, core.float32,): ("__nv_fmaf_rd", core.float32),
                               }, _builder)


@extern.extern
def fmaf_ru(arg0, arg1, arg2, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
                              {(core.float32, core.float32, core.float32,): ("__nv_fmaf_ru", core.float32),
                               }, _builder)


@extern.extern
def fmaf_ieee_rn(arg0, arg1, arg2, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
                              {(core.float32, core.float32, core.float32,): ("__nv_fmaf_ieee_rn", core.float32),
                               }, _builder)


@extern.extern
def fmaf_ieee_rz(arg0, arg1, arg2, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
                              {(core.float32, core.float32, core.float32,): ("__nv_fmaf_ieee_rz", core.float32),
                               }, _builder)


@extern.extern
def fmaf_ieee_rd(arg0, arg1, arg2, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
                              {(core.float32, core.float32, core.float32,): ("__nv_fmaf_ieee_rd", core.float32),
                               }, _builder)


@extern.extern
def fmaf_ieee_ru(arg0, arg1, arg2, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
                              {(core.float32, core.float32, core.float32,): ("__nv_fmaf_ieee_ru", core.float32),
                               }, _builder)


@extern.extern
def fma_rn(arg0, arg1, arg2, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
                              {(core.float64, core.float64, core.float64,): ("__nv_fma_rn", core.float64),
                               }, _builder)


@extern.extern
def fma_rz(arg0, arg1, arg2, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
                              {(core.float32, core.float32, core.float32,): ("__nv_fmaf_rz", core.float32),
                               (core.float64, core.float64, core.float64,): ("__nv_fma_rz", core.float64),
                              {(core.float64, core.float64, core.float64,): ("__nv_fma_rz", core.float64),
                               }, _builder)


@extern.extern
def fma_rd(arg0, arg1, arg2, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
                              {(core.float32, core.float32, core.float32,): ("__nv_fmaf_rd", core.float32),
                               (core.float64, core.float64, core.float64,): ("__nv_fma_rd", core.float64),
                              {(core.float64, core.float64, core.float64,): ("__nv_fma_rd", core.float64),
                               }, _builder)


@extern.extern
def fma_ru(arg0, arg1, arg2, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
                              {(core.float32, core.float32, core.float32,): ("__nv_fmaf_ru", core.float32),
                               (core.float64, core.float64, core.float64,): ("__nv_fma_ru", core.float64),
                              {(core.float64, core.float64, core.float64,): ("__nv_fma_ru", core.float64),
                               }, _builder)


@extern.extern
def fast_dividef(arg0, arg1, _builder=None):
def fast_fdividef(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float32, core.float32,): ("__nv_fast_fdividef", core.float32),
                               }, _builder)


@extern.extern
def div_rn(arg0, arg1, _builder=None):
def fdiv_rn(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float32, core.float32,): ("__nv_fdiv_rn", core.float32),
                               (core.float64, core.float64,): ("__nv_ddiv_rn", core.float64),
                               }, _builder)


@extern.extern
def div_rz(arg0, arg1, _builder=None):
def fdiv_rz(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float32, core.float32,): ("__nv_fdiv_rz", core.float32),
                               (core.float64, core.float64,): ("__nv_ddiv_rz", core.float64),
                               }, _builder)


@extern.extern
def div_rd(arg0, arg1, _builder=None):
def fdiv_rd(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float32, core.float32,): ("__nv_fdiv_rd", core.float32),
                               (core.float64, core.float64,): ("__nv_ddiv_rd", core.float64),
                               }, _builder)


@extern.extern
def div_ru(arg0, arg1, _builder=None):
def fdiv_ru(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float32, core.float32,): ("__nv_fdiv_ru", core.float32),
                               (core.float64, core.float64,): ("__nv_ddiv_ru", core.float64),
                               }, _builder)


@extern.extern
def rcp_rn(arg0, _builder=None):
def frcp_rn(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.float32,): ("__nv_frcp_rn", core.float32),
                               (core.float64,): ("__nv_drcp_rn", core.float64),
                               }, _builder)


@extern.extern
def rcp_rz(arg0, _builder=None):
def frcp_rz(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.float32,): ("__nv_frcp_rz", core.float32),
                               (core.float64,): ("__nv_drcp_rz", core.float64),
                               }, _builder)


@extern.extern
def rcp_rd(arg0, _builder=None):
def frcp_rd(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.float32,): ("__nv_frcp_rd", core.float32),
                               (core.float64,): ("__nv_drcp_rd", core.float64),
                               }, _builder)


@extern.extern
def rcp_ru(arg0, _builder=None):
def frcp_ru(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.float32,): ("__nv_frcp_ru", core.float32),
                               (core.float64,): ("__nv_drcp_ru", core.float64),
                               }, _builder)


@extern.extern
def sqrt_rn(arg0, _builder=None):
def fsqrt_rn(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.float32,): ("__nv_fsqrt_rn", core.float32),
                               (core.float64,): ("__nv_dsqrt_rn", core.float64),
                               }, _builder)


@extern.extern
def sqrt_rz(arg0, _builder=None):
def fsqrt_rz(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.float32,): ("__nv_fsqrt_rz", core.float32),
                               (core.float64,): ("__nv_dsqrt_rz", core.float64),
                               }, _builder)


@extern.extern
def sqrt_rd(arg0, _builder=None):
def fsqrt_rd(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.float32,): ("__nv_fsqrt_rd", core.float32),
                               (core.float64,): ("__nv_dsqrt_rd", core.float64),
                               }, _builder)


@extern.extern
def sqrt_ru(arg0, _builder=None):
def fsqrt_ru(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.float32,): ("__nv_fsqrt_ru", core.float32),
                               (core.float64,): ("__nv_dsqrt_ru", core.float64),
                               }, _builder)


@extern.extern
def ddiv_rn(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float64, core.float64,): ("__nv_ddiv_rn", core.float64),
                               }, _builder)


@extern.extern
def ddiv_rz(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float64, core.float64,): ("__nv_ddiv_rz", core.float64),
                               }, _builder)


@extern.extern
def ddiv_rd(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float64, core.float64,): ("__nv_ddiv_rd", core.float64),
                               }, _builder)


@extern.extern
def ddiv_ru(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float64, core.float64,): ("__nv_ddiv_ru", core.float64),
                               }, _builder)


@extern.extern
def drcp_rn(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.float64,): ("__nv_drcp_rn", core.float64),
                               }, _builder)


@extern.extern
def drcp_rz(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.float64,): ("__nv_drcp_rz", core.float64),
                               }, _builder)


@extern.extern
def drcp_rd(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.float64,): ("__nv_drcp_rd", core.float64),
                               }, _builder)


@extern.extern
def drcp_ru(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.float64,): ("__nv_drcp_ru", core.float64),
                               }, _builder)


@extern.extern
def dsqrt_rn(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.float64,): ("__nv_dsqrt_rn", core.float64),
                               }, _builder)


@extern.extern
def dsqrt_rz(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.float64,): ("__nv_dsqrt_rz", core.float64),
                               }, _builder)


@extern.extern
def dsqrt_rd(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.float64,): ("__nv_dsqrt_rd", core.float64),
                               }, _builder)


@extern.extern
def dsqrt_ru(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.float64,): ("__nv_dsqrt_ru", core.float64),
                               }, _builder)

@@ -295,66 +425,114 @@ def sqrt(arg0, _builder=None):


@extern.extern
def add_rn(arg0, arg1, _builder=None):
def dadd_rn(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float64, core.float64,): ("__nv_dadd_rn", core.float64),
                               (core.float32, core.float32,): ("__nv_fadd_rn", core.float32),
                               }, _builder)


@extern.extern
def add_rz(arg0, arg1, _builder=None):
def dadd_rz(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float64, core.float64,): ("__nv_dadd_rz", core.float64),
                               (core.float32, core.float32,): ("__nv_fadd_rz", core.float32),
                               }, _builder)


@extern.extern
def add_rd(arg0, arg1, _builder=None):
def dadd_rd(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float64, core.float64,): ("__nv_dadd_rd", core.float64),
                               (core.float32, core.float32,): ("__nv_fadd_rd", core.float32),
                               }, _builder)


@extern.extern
def add_ru(arg0, arg1, _builder=None):
def dadd_ru(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float64, core.float64,): ("__nv_dadd_ru", core.float64),
                               (core.float32, core.float32,): ("__nv_fadd_ru", core.float32),
                               }, _builder)


@extern.extern
def mul_rn(arg0, arg1, _builder=None):
def dmul_rn(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float64, core.float64,): ("__nv_dmul_rn", core.float64),
                               (core.float32, core.float32,): ("__nv_fmul_rn", core.float32),
                               }, _builder)


@extern.extern
def mul_rz(arg0, arg1, _builder=None):
def dmul_rz(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float64, core.float64,): ("__nv_dmul_rz", core.float64),
                               (core.float32, core.float32,): ("__nv_fmul_rz", core.float32),
                               }, _builder)


@extern.extern
def mul_rd(arg0, arg1, _builder=None):
def dmul_rd(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float64, core.float64,): ("__nv_dmul_rd", core.float64),
                               (core.float32, core.float32,): ("__nv_fmul_rd", core.float32),
                               }, _builder)


@extern.extern
def mul_ru(arg0, arg1, _builder=None):
def dmul_ru(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float64, core.float64,): ("__nv_dmul_ru", core.float64),
                               (core.float32, core.float32,): ("__nv_fmul_ru", core.float32),
                               }, _builder)


@extern.extern
def fadd_rd(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float32, core.float32,): ("__nv_fadd_rd", core.float32),
                               }, _builder)


@extern.extern
def fadd_ru(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float32, core.float32,): ("__nv_fadd_ru", core.float32),
                               }, _builder)


@extern.extern
def fmul_rd(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float32, core.float32,): ("__nv_fmul_rd", core.float32),
                               }, _builder)


@extern.extern
def fmul_ru(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float32, core.float32,): ("__nv_fmul_ru", core.float32),
                               }, _builder)


@extern.extern
def fadd_rn(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float32, core.float32,): ("__nv_fadd_rn", core.float32),
                               }, _builder)


@extern.extern
def fadd_rz(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float32, core.float32,): ("__nv_fadd_rz", core.float32),
                               }, _builder)


@extern.extern
def fmul_rn(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float32, core.float32,): ("__nv_fmul_rn", core.float32),
                               }, _builder)


@extern.extern
def fmul_rz(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float32, core.float32,): ("__nv_fmul_rz", core.float32),
                               }, _builder)

@@ -446,13 +624,7 @@ def double2uint_ru(arg0, _builder=None):
def int2double_rn(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.int32,): ("__nv_int2double_rn", core.float64),
                               }, _builder)


@extern.extern
def uint2double_rn(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.uint32,): ("__nv_uint2double_rn", core.float64),
                               }, _builder)

@@ -516,6 +688,7 @@ def float2uint_ru(arg0, _builder=None):
def int2float_rn(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.int32,): ("__nv_int2float_rn", core.float32),
                               (core.uint32,): ("__nv_uint2float_rn", core.float32),
                               }, _builder)


@@ -523,6 +696,7 @@ def int2float_rn(arg0, _builder=None):
def int2float_rz(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.int32,): ("__nv_int2float_rz", core.float32),
                               (core.uint32,): ("__nv_uint2float_rz", core.float32),
                               }, _builder)


@@ -530,6 +704,7 @@ def int2float_rz(arg0, _builder=None):
def int2float_rd(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.int32,): ("__nv_int2float_rd", core.float32),
                               (core.uint32,): ("__nv_uint2float_rd", core.float32),
                               }, _builder)

@@ -537,34 +712,7 @@ def int2float_rd(arg0, _builder=None):
def int2float_ru(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.int32,): ("__nv_int2float_ru", core.float32),
                               }, _builder)


@extern.extern
def uint2float_rn(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.uint32,): ("__nv_uint2float_rn", core.float32),
                               }, _builder)


@extern.extern
def uint2float_rz(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.uint32,): ("__nv_uint2float_rz", core.float32),
                               }, _builder)


@extern.extern
def uint2float_rd(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.uint32,): ("__nv_uint2float_rd", core.float32),
                               }, _builder)


@extern.extern
def uint2float_ru(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.uint32,): ("__nv_uint2float_ru", core.float32),
                               }, _builder)

@@ -705,6 +853,7 @@ def double2ull_ru(arg0, _builder=None):
def ll2float_rn(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.int64,): ("__nv_ll2float_rn", core.float32),
                               (core.uint64,): ("__nv_ull2float_rn", core.float32),
                               }, _builder)


@@ -712,6 +861,7 @@ def ll2float_rn(arg0, _builder=None):
def ll2float_rz(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.int64,): ("__nv_ll2float_rz", core.float32),
                               (core.uint64,): ("__nv_ull2float_rz", core.float32),
                               }, _builder)


@@ -719,6 +869,7 @@ def ll2float_rz(arg0, _builder=None):
def ll2float_rd(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.int64,): ("__nv_ll2float_rd", core.float32),
                               (core.uint64,): ("__nv_ull2float_rd", core.float32),
                               }, _builder)

@@ -726,34 +877,7 @@ def ll2float_rd(arg0, _builder=None):
def ll2float_ru(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.int64,): ("__nv_ll2float_ru", core.float32),
                               }, _builder)


@extern.extern
def ull2float_rn(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.uint64,): ("__nv_ull2float_rn", core.float32),
                               }, _builder)


@extern.extern
def ull2float_rz(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.uint64,): ("__nv_ull2float_rz", core.float32),
                               }, _builder)


@extern.extern
def ull2float_rd(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.uint64,): ("__nv_ull2float_rd", core.float32),
                               }, _builder)


@extern.extern
def ull2float_ru(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.uint64,): ("__nv_ull2float_ru", core.float32),
                               }, _builder)

@@ -761,6 +885,7 @@ def ull2float_ru(arg0, _builder=None):
def ll2double_rn(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.int64,): ("__nv_ll2double_rn", core.float64),
                               (core.uint64,): ("__nv_ull2double_rn", core.float64),
                               }, _builder)


@@ -768,6 +893,7 @@ def ll2double_rn(arg0, _builder=None):
def ll2double_rz(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.int64,): ("__nv_ll2double_rz", core.float64),
                               (core.uint64,): ("__nv_ull2double_rz", core.float64),
                               }, _builder)


@@ -775,6 +901,7 @@ def ll2double_rz(arg0, _builder=None):
def ll2double_rd(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.int64,): ("__nv_ll2double_rd", core.float64),
                               (core.uint64,): ("__nv_ull2double_rd", core.float64),
                               }, _builder)

@@ -782,34 +909,7 @@ def ll2double_rd(arg0, _builder=None):
def ll2double_ru(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.int64,): ("__nv_ll2double_ru", core.float64),
                               }, _builder)


@extern.extern
def ull2double_rn(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.uint64,): ("__nv_ull2double_rn", core.float64),
                               }, _builder)


@extern.extern
def ull2double_rz(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.uint64,): ("__nv_ull2double_rz", core.float64),
                               }, _builder)


@extern.extern
def ull2double_rd(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.uint64,): ("__nv_ull2double_rd", core.float64),
                               }, _builder)


@extern.extern
def ull2double_ru(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.uint64,): ("__nv_ull2double_ru", core.float64),
                               }, _builder)

@@ -817,6 +917,7 @@ def ull2double_ru(arg0, _builder=None):
def int_as_float(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.int32,): ("__nv_int_as_float", core.float32),
                               (core.uint32,): ("__nv_uint_as_float", core.float32),
                               }, _builder)


@@ -827,13 +928,6 @@ def float_as_int(arg0, _builder=None):
                               }, _builder)


@extern.extern
def uint_as_float(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.uint32,): ("__nv_uint_as_float", core.float32),
                               }, _builder)


@extern.extern
def float_as_uint(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
@@ -912,9 +1006,11 @@ def fast_log10f(arg0, _builder=None):


@extern.extern
def fast_powf(arg0, arg1, _builder=None):
def pow(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float32, core.float32,): ("__nv_fast_powf", core.float32),
                               (core.float32, core.float32,): ("__nv_powf", core.float32),
                               (core.float64, core.float64,): ("__nv_pow", core.float64),
                               }, _builder)

@@ -935,39 +1031,35 @@ def rhadd(arg0, arg1, _builder=None):


@extern.extern
def sub_rn(arg0, arg1, _builder=None):
def fsub_rn(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float32, core.float32,): ("__nv_fsub_rn", core.float32),
                               (core.float64, core.float64,): ("__nv_dsub_rn", core.float64),
                               }, _builder)


@extern.extern
def sub_rz(arg0, arg1, _builder=None):
def fsub_rz(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float32, core.float32,): ("__nv_fsub_rz", core.float32),
                               (core.float64, core.float64,): ("__nv_dsub_rz", core.float64),
                               }, _builder)


@extern.extern
def sub_rd(arg0, arg1, _builder=None):
def fsub_rd(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float32, core.float32,): ("__nv_fsub_rd", core.float32),
                               (core.float64, core.float64,): ("__nv_dsub_rd", core.float64),
                               }, _builder)


@extern.extern
def sub_ru(arg0, arg1, _builder=None):
def fsub_ru(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float32, core.float32,): ("__nv_fsub_ru", core.float32),
                               (core.float64, core.float64,): ("__nv_dsub_ru", core.float64),
                               }, _builder)


@extern.extern
def rsqrt_rn(arg0, _builder=None):
def frsqrt_rn(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.float32,): ("__nv_frsqrt_rn", core.float32),
                               }, _builder)
@@ -1006,18 +1098,16 @@ def nearbyint(arg0, _builder=None):


@extern.extern
def isnan(arg0, _builder=None):
def isnanf(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.float32,): ("__nv_isnanf", core.int32),
                               (core.float64,): ("__nv_isnand", core.int32),
                               }, _builder)


@extern.extern
def signbit(arg0, _builder=None):
def signbitf(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.float32,): ("__nv_signbitf", core.int32),
                               (core.float64,): ("__nv_signbitd", core.int32),
                               }, _builder)


@@ -1037,10 +1127,9 @@ def finitef(arg0, _builder=None):


@extern.extern
def isinf(arg0, _builder=None):
def isinff(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.float32,): ("__nv_isinff", core.int32),
                               (core.float64,): ("__nv_isinfd", core.int32),
                               }, _builder)

@@ -1461,12 +1550,10 @@ def fma(arg0, arg1, arg2, _builder=None):


@extern.extern
def pow(arg0, arg1, _builder=None):
def powi(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float32, core.int32,): ("__nv_powif", core.float32),
                               (core.float64, core.int32,): ("__nv_powi", core.float64),
                               (core.float32, core.float32,): ("__nv_powf", core.float32),
                               (core.float64, core.float64,): ("__nv_pow", core.float64),
                               }, _builder)

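After the renames, `pow` handles floating exponents (`__nv_powf`/`__nv_pow`) while `powi` keeps the integer-exponent overloads; a usage sketch (tensor names hypothetical, inside a @triton.jit kernel):

# x: float32 tensor, n: int32 tensor
y1 = libdevice.pow(x, x)    # float ** float -> __nv_powf
y2 = libdevice.powi(x, n)   # float ** int   -> __nv_powif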
@@ -1518,8 +1605,57 @@ def logb(arg0, _builder=None):
                               }, _builder)


@extern.extern
def signbitd(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.float64,): ("__nv_signbitd", core.int32),
                               }, _builder)


@extern.extern
def isfinited(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.float64,): ("__nv_isfinited", core.int32),
                               }, _builder)


@extern.extern
def isinfd(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.float64,): ("__nv_isinfd", core.int32),
                               }, _builder)


@extern.extern
def isnand(arg0, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
                              {(core.float64,): ("__nv_isnand", core.int32),
                               }, _builder)


@extern.extern
def dsub_rn(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float64, core.float64,): ("__nv_dsub_rn", core.float64),
                               }, _builder)


@extern.extern
def dsub_rz(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float64, core.float64,): ("__nv_dsub_rz", core.float64),
                               }, _builder)


@extern.extern
def dsub_ru(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float64, core.float64,): ("__nv_dsub_ru", core.float64),
                               }, _builder)


@extern.extern
def dsub_rd(arg0, arg1, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
                              {(core.float64, core.float64,): ("__nv_dsub_rd", core.float64),
                               }, _builder)

@@ -6,6 +6,8 @@ from . import core as tl
from triton._C.libtriton.triton import ir

import torch

# Create custom exception that prints message "hello"
class IncompatibleTypeErrorimpl(Exception):
    def __init__(self, type_a, type_b):
@@ -969,6 +971,11 @@ def dot(a: tl.tensor,
        trans_b: bool,
        allow_tf32: bool,
        builder: ir.builder) -> tl.tensor:

    if torch.version.hip is not None:
        a = cast(a, tl.float32, builder)
        b = cast(b, tl.float32, builder)

    in_a = 1 if not trans_a else 0
    in_b = 1 if trans_b else 0
    assert a.type.is_block() and b.type.is_block()
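The torch.version.hip branch added above is the ROCm port's precision fallback: on HIP builds, both tl.dot operands are unconditionally upcast to float32 before the multiply. A small probe of the guard itself (plain PyTorch; torch.version.hip is None on CUDA wheels and a version string such as "5.2.3" on ROCm wheels):

import torch

if torch.version.hip is not None:
    # ROCm build: dot operands will be cast to float32 by the change above
    print(f"HIP {torch.version.hip}: tl.dot computes in float32")
else:
    print("CUDA build: tl.dot keeps the original operand types")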
@@ -270,7 +270,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
    # build stub signature -- includes arguments that are specialized
    for i, arg in constants.items():
        if callable(arg):
            raise TypeError(f"Callable constexpr at index {{i}} is not supported")
            raise TypeError(f"Callable constexpr at index {i} is not supported")
    if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
        bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs)
        if not warmup:
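Context for the {{i}} -> {i} change: this raise sits in the f-string template from which the JIT launcher source is generated, so the escaping decides whether the index is interpolated when the template is rendered or left as a literal {i} for later. The f-string mechanics in isolation (standalone Python, not Triton code):

i = 3
escaped = f'Callable constexpr at index {{i}} is not supported'
interpolated = f'Callable constexpr at index {i} is not supported'
print(escaped)       # ... index {i} ...  (placeholder survives rendering)
print(interpolated)  # ... index 3 ...    (index baked in at render time)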
@@ -115,9 +115,7 @@ def nvsmi(attrs):
    return ret


def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
             percentiles=(0.5, 0.2, 0.8),
             record_clocks=False, fast_flush=False):
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0.8], record_clocks=False):
    """
    Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
    the 20th and 80th performance percentiles.
@@ -132,8 +130,6 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
    :type grad_to_none: torch.tensor, optional
    :param percentiles: Performance percentiles to return in addition to the median.
    :type percentiles: list[float]
    :param fast_flush: Use faster kernel to flush L2 between measurements
    :type fast_flush: bool
    """

    # Estimate the runtime of the function
@@ -155,10 +151,7 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
    # doesn't contain any input data before the run
    start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
    end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
    if fast_flush:
        cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
    else:
        cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
    cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
    # Warm-up
    for _ in range(n_warmup):
        fn()
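For reference, the helper is typically driven like this (hypothetical workload; with percentiles=[0.5, 0.2, 0.8] it should return the median plus the 20th/80th percentiles, in milliseconds):

import torch
import triton

a = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)
b = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)
ms, ms_p20, ms_p80 = triton.testing.do_bench(lambda: torch.matmul(a, b))
print(f"median {ms:.3f} ms (p20 {ms_p20:.3f}, p80 {ms_p80:.3f})")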
@@ -1,24 +1,10 @@
import argparse
import subprocess
from abc import ABC, abstractmethod
from typing import Dict, List, Optional


class Symbol:
    _name: str
    _op_name: str
    _ret_type: str
    _arg_names: List[str]
    _arg_types: List[str]

    def __init__(
        self,
        name: str,
        op_name: str,
        ret_type: str,
        arg_names: List[str],
        arg_types: List[str],
    ) -> None:
    def __init__(self, name: str, op_name: str, ret_type: str, arg_names: list, arg_types: list) -> None:
        '''
        A symbol is a function declaration.

@@ -31,31 +17,31 @@ class Symbol:
        self._name = name
        self._op_name = op_name
        self._ret_type = ret_type
        self._arg_names = list(arg_names)
        self._arg_types = list(arg_types)
        self._arg_names = arg_names
        self._arg_types = arg_types

    @property
    def name(self) -> str:
    def name(self):
        return self._name

    @property
    def op_name(self) -> str:
    def op_name(self):
        return self._op_name

    @property
    def ret_type(self) -> str:
    def ret_type(self):
        return self._ret_type

    @property
    def arg_names(self) -> List[str]:
    def arg_names(self):
        return self._arg_names

    @property
    def arg_types(self) -> List[str]:
    def arg_types(self):
        return self._arg_types

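Concretely, a Symbol is just a record of one libdevice declaration. A hypothetical construction, matching the __nv_fsub_rz entries earlier in this diff:

sym = Symbol(name="__nv_fsub_rz",
             op_name="fsub_rz",
             ret_type="float32",
             arg_names=["arg0", "arg1"],
             arg_types=["float32", "float32"])
print(sym.op_name, sym.ret_type)  # fsub_rz float32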
def convert_type(type_str) -> Optional[str]:
def convert_type(type_str):
    if type_str == "i32":
        return "int32"
    elif type_str == "u32":
@@ -73,7 +59,7 @@ def convert_type(type_str) -> Optional[str]:
    return None


def to_unsigned(type_str) -> str:
def to_unsigned(type_str):
    if type_str == "int32":
        return "uint32"
    elif type_str == "int64":
@@ -83,19 +69,7 @@ def to_unsigned(type_str) -> str:


class ExternLibrary(ABC):
    _name: str
    _path: str
    _symbols: Dict[str, Symbol]
    _format: bool
    _grouping: bool

    def __init__(
        self,
        name: str,
        path: str,
        format: bool = True,
        grouping: bool = True,
    ) -> None:
    def __init__(self, name: str, path: str, format: bool = True, grouping: bool = True) -> None:
        '''
        Abstract class for extern library.

@@ -106,34 +80,34 @@ class ExternLibrary(ABC):
        self._name = name
        self._path = path
        self._symbols = {}
        self._format = format
        self._format = True
        self._grouping = grouping

    @property
    def name(self) -> str:
    def name(self):
        return self._name

    @property
    def path(self) -> str:
    def path(self):
        return self._path

    @property
    def symbols(self) -> Dict[str, Symbol]:
    def symbols(self):
        return self._symbols

    @property
    def grouping(self) -> bool:
    def grouping(self):
        return self._grouping

    @abstractmethod
    def parse_symbols(self, input_file) -> None:
    def parse_symbols(self, input_file):
        pass

    @abstractmethod
    def _output_stubs(self) -> str:
        pass

    def generate_stub_file(self, output_dir) -> None:
    def generate_stub_file(self, output_dir):
        file_str = self._output_stubs()
        if file_str is None or len(file_str) == 0:
            raise Exception("file_str is empty")
@@ -149,8 +123,6 @@ class ExternLibrary(ABC):


class Libdevice(ExternLibrary):
    _symbol_groups: Dict[str, List[Symbol]]

    def __init__(self, path) -> None:
        '''
        Constructor for Libdevice.
@@ -160,7 +132,7 @@ class Libdevice(ExternLibrary):
        super().__init__("libdevice", path)
        self._symbol_groups = {}

    def _extract_symbol(self, line) -> Optional[Symbol]:
    def _extract_symbol(self, line):
        # Extract symbols from line in the following format:
        # "define [internal] <ret_type> @<name>(<arg_types>,)"
        entries = line.split("@")
@@ -177,9 +149,6 @@ class Libdevice(ExternLibrary):
        func_strs = func_str.split("(")
        func_name = func_strs[0].replace("@", "")
        op_name = func_name.replace("__nv_", "")
        # To filter some interfaces unlisted in NVIDIA's official documents.
        if 'ieee' in op_name:
            return None
        # Get arg_types
        arg_strs = func_strs[1].split(",")
        arg_types = []
@@ -202,77 +171,66 @@ class Libdevice(ExternLibrary):
                arg_types[i] = to_unsigned(arg_type)
        return Symbol(func_name, op_name, ret_type, arg_names, arg_types)

    def _group_symbols(self) -> None:
    def _group_symbols(self):
        symbol_set = {}
        for symbol in self._symbols.values():
            op_name = symbol.op_name
            symbol_set[op_name] = symbol

        # Group functions together by renaming.
        renaming = {
            'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh',
            'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn': 'add_rn',
            'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru',
            'dadd_rz': 'add_rz', 'fadd_rz': 'add_rz', 'asinf': 'asin',
            'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2',
            'atanhf': 'atanh', 'brevll': 'brev', 'cbrtf': 'cbrt',
            'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign',
            'cosf': 'cos', 'coshf': 'cosh', 'cospif': 'cospi',
            'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1',
            'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn',
            'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru', 'ddiv_ru': 'div_ru',
            'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf',
            'erfcf': 'erfc', 'erfcinvf': 'erfcinv', 'erfcxf': 'erfcx',
            'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10',
            'exp2f': 'exp2', 'expm1f': 'expm1', 'fabsf': 'abs',
            'fabs': 'abs', 'fast_fdividef': 'fast_dividef',
            'fdimf': 'fdim', 'ffsll': 'ffs', 'floorf': 'floor',
            'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn',
            'fmaf_ru': 'fma_ru', 'fmaf_rz': 'fma_rz', 'fmodf': 'fmod',
            'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb',
            'isinff': 'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan',
            'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn',
            'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint',
            'llroundf': 'llround', 'logf': 'log', 'log10f': 'log10',
            'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb',
            'umax': 'max', 'llmax': 'max', 'ullmax': 'max', 'fmaxf': 'max',
            'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min',
            'fminf': 'min', 'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd',
            'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn', 'dmul_ru': 'mul_ru',
            'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz',
            'umul24': 'mul24', 'umulhi': 'mulhi', 'mul64hi': 'mulhi',
            'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf': 'nextafter',
            'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf',
            'normcdfinvf': 'normcdfinv', 'popcll': 'popc', 'powif': 'pow', 'powi': 'pow',
            'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd', 'drcp_rd': 'rcp_rd',
            'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru',
            'drcp_ru': 'rcp_ru', 'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz',
            'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot',
            'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d',
            'roundf': 'round', 'rsqrtf': 'rsqrt', 'frsqrt_rn': 'rsqrt_rn',
            'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit',
            'signbitd': 'signbit', 'sinf': 'sin', 'sinhf': 'sinh',
            'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd',
            'dsqrt_rd': 'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn',
            'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru', 'fsqrt_rz': 'sqrt_rz',
            'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd',
            'fsub_rn': 'sub_rn', 'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru',
            'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz',
            'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc',
            'y0f': 'y0', 'y1f': 'y1', 'ynf': 'yn'
        }

        # The following cases are grouped together:
        # op_name, <u/ull/ll>op_name<ll/f/i>
        for symbol in self._symbols.values():
            op_name = symbol.op_name
            if op_name in renaming:
                op_name = renaming[op_name]
            if "max" in op_name:
                op_name = "max"
            elif "min" in op_name:
                op_name = "min"
            elif "abs" in op_name:
                op_name = "abs"
            elif "pow" in op_name and "fast" in op_name:
                op_name = "pow"
            elif "round" in op_name:
                if "llround" in op_name:
                    op_name = "llround"
                else:
                    op_name = "round"
            elif "rint" in op_name:
                if "llrint" in op_name:
                    op_name = "llrint"
                else:
                    op_name = "rint"
            elif op_name.startswith("ull"):
                if "2" not in op_name:
                    # e.g., ullmax->max
                    op_name = op_name[3:]
                else:
                    # e.g., ull2double->ll2double
                    op_name = op_name[1:]
            elif op_name.startswith("u"):
                if "2" not in op_name:
                    # e.g., uhadd->hadd
                    op_name = op_name[1:]
                else:
                    # e.g., uint2double_rn->int2double_rn
                    op_name = op_name[1:]
            elif op_name.startswith("ll"):
                if "2" not in op_name:
                    # e.g., llmax->max
                    op_name = op_name[2:]
            elif op_name.endswith("ll"):
                op_name = op_name[:-2]
            elif op_name.endswith("f"):
                op_name = op_name[:-1]
            if op_name in symbol_set:
                # Update op_name only if there's an existing symbol
                symbol._op_name = op_name
            else:
                op_name = symbol._op_name
            if op_name in self._symbol_groups:
                self._symbol_groups[op_name].append(symbol)
            else:
                self._symbol_groups[op_name] = [symbol]

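The net effect of the renaming table plus the prefix/suffix rules is easiest to see on a few names. A toy trace (plain Python, with a deliberately truncated renaming table for illustration):

renaming = {'fmaxf': 'max', 'fsub_rz': 'sub_rz'}
for op_name in ('fmaxf', 'ullmax', 'isnanf'):
    grouped = renaming.get(op_name, op_name)
    if grouped.startswith('ull') and '2' not in grouped:
        grouped = grouped[3:]      # ullmax -> max
    elif grouped.endswith('f'):
        grouped = grouped[:-1]     # isnanf -> isnan
    print(op_name, '->', grouped)  # fmaxf -> max, ullmax -> max, isnanf -> isnan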
    def parse_symbols(self, input_file) -> None:
    def parse_symbols(self, input_file):
        if len(self.symbols) > 0:
            return
        output = subprocess.check_output(["grep", "define", input_file]).decode().splitlines()
@@ -284,7 +242,7 @@ class Libdevice(ExternLibrary):

        self._group_symbols()

    def _output_stubs(self) -> str:
    def _output_stubs(self):
        # Generate python functions in the following format:
        # @extern.extern
        # def <op_name>(<args>, _builder=None):
@@ -292,7 +250,7 @@ class Libdevice(ExternLibrary):
        #     return extern.dispatch("libdevice", <path>, <args>, <arg_type_symbol_dict>, _builder)
        import_str = "from . import core, extern\n"
        import_str += "import os\n"
        header_str = "LIBDEVICE_PATH = os.path.dirname(\n\tos.path.abspath(__file__)) + \"/libdevice.10.bc\"\n"
        header_str = "LIBDEVICE_PATH = os.path.dirname(os.path.abspath(__file__)) + \"/libdevice.10.bc\"\n"
        func_str = ""
        for symbols in self._symbol_groups.values():
            func_str += "@extern.extern\n"
@@ -325,10 +283,7 @@ class Libdevice(ExternLibrary):


class LLVMDisassembler:
    _path: str
    _ll_file: str

    def __init__(self, path) -> None:
    def __init__(self, path):
        '''
        Invoke llvm-dis to disassemble the given file.

@@ -337,28 +292,23 @@ class LLVMDisassembler:
        self._path = path
        self._ll_file = "/tmp/extern_lib.ll"

    def disasm(self, lib_path: str) -> None:
    def disasm(self, lib_path):
        subprocess.Popen([self._path, lib_path, "-o", self.ll_file],
                         stdout=subprocess.PIPE).communicate()

    @property
    def ll_file(self) -> str:
    def ll_file(self):
        return self._ll_file

    @property
    def path(self) -> str:
    def path(self):
        return self._path


extern_libs = ["libdevice"]


def build(
    llvm_dis_path: str,
    lib_path: str,
    lib_name: str,
    output_dir: str,
) -> None:
def build(llvm_dis_path, lib_path, lib_name, output_dir):
    '''
    Interface function to build the library file.

@@ -381,10 +331,10 @@ def build(

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--llvm-dis", dest="llvm_dis_path", help="Path to llvm-dis", default="llvm-dis")
    parser.add_argument("--lib-path", dest="lib_path", help="Path to the extern library")
    parser.add_argument("--lib-name", dest="lib_name", help="Name of the extern library")
    parser.add_argument("--output", dest="output_dir", help="Output file path", default="/tmp/")
    parser.add_argument("-llvm", dest="llvm_dis_path", help="path to llvm-dis", default="llvm-dis")
    parser.add_argument("--lib-path", dest="lib_path", help="path to the extern library")
    parser.add_argument("--lib-name", dest="lib_name", help="name of the extern library")
    parser.add_argument("-o", dest="output_dir", help="output file path", default="/tmp/")
    args = parser.parse_args()

    build(args.llvm_dis_path, args.lib_path, args.lib_name, args.output_dir)

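Run end to end, the script disassembles a libdevice bitcode file and regenerates the stub module. Invoked directly, it amounts to this (hypothetical paths; flags per the rocm-side argparse above):

# Equivalent CLI: python <this script> -llvm llvm-dis --lib-path libdevice.10.bc --lib-name libdevice -o /tmp/
build("llvm-dis", "libdevice.10.bc", "libdevice", "/tmp/")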
@@ -7,7 +7,7 @@ Please refer to https://docs.nvidia.com/cuda/libdevice-users-guide/index.html re

In `triton/language/libdevice.py`, we try to aggregate functions with the same computation but different data types together.
For example, both `__nv_asin` and `__nv_asinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`.
Using triton, you can simply call `tl.libdevice.asin`.
Using triton, you can simply call `tl.libdevice.asinf`.
triton automatically selects the correct underlying device function to invoke based on input and output types.
"""

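A minimal sketch of that call from kernel code (hypothetical kernel; uses the ungrouped asinf name from the right-hand side of this diff):

import triton
import triton.language as tl

@triton.jit
def asin_kernel(X, Y, N, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(X + offs, mask=mask)
    # A float32 input dispatches to __nv_asinf under the hood.
    tl.store(Y + offs, tl.libdevice.asinf(x), mask=mask)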
14
triton_rocm_20-52.Dockerfile
Normal file
@@ -0,0 +1,14 @@
FROM rocm/pytorch:rocm5.2.3_ubuntu20.04_py3.7_pytorch_1.12.1

# build triton
ENV TRITON_USE_ROCM=ON MI_GPU_ARCH=gfx90a

# Unit Tests
# to run unit tests
# 1. build this Dockerfile
#    docker build -f triton_rocm_20-52.Dockerfile -t triton_rocm52 .
# 2. run docker container
#    docker run -it --rm --network=host --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --name triton --ipc=host --shm-size 16G --device=/dev/kfd --device=/dev/dri triton_rocm52:latest
# 3. run core unit tests on a rocm machine
#    cd ~/triton/python
#    pytest --verbose test/unit/language/test_core.py | tee test_core.log