59 Commits

Author SHA1 Message Date
rsanthanam-amd
46fd581b0a Merge pull request #29 from ROCmSoftwarePlatform/parse_amdgcn_from_rocminfo
Changes to eliminate the need for the MI_GPU_ARCH environment variable.
2022-11-18 12:53:25 -06:00
Rohit Santhanam
8cc448d92e Changes to eliminate the need for the MI_GPU_ARCH environment variable.
The AMDGPU arch is now parsed out of the rocminfo dump.
2022-11-18 18:51:57 +00:00
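A minimal sketch of how the arch string might be recovered from a rocminfo dump, assuming a POSIX popen and that the dump contains an agent name such as gfx90a; the helper name and details are illustrative, not the commit's actual code:

#include <array>
#include <cstdio>     // popen/pclose (POSIX)
#include <memory>
#include <regex>
#include <stdexcept>
#include <string>

// Illustrative helper: capture the rocminfo dump and pull out the first
// amdgcn arch token (e.g. "gfx90a").
inline std::string parse_amdgcn_arch_from_rocminfo() {
  std::array<char, 256> buf{};
  std::string dump;
  std::unique_ptr<FILE, decltype(&pclose)> pipe(popen("rocminfo", "r"), &pclose);
  if (!pipe)
    throw std::runtime_error("failed to run rocminfo");
  while (fgets(buf.data(), buf.size(), pipe.get()))
    dump += buf.data();
  std::smatch m;
  if (std::regex_search(dump, m, std::regex("gfx[0-9a-f]+")))
    return m.str();
  throw std::runtime_error("no amdgcn arch found in rocminfo dump");
}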
Michael Melesse
9a9fabbba9 Merge pull request #22 from ROCmSoftwarePlatform/IFU_11_1_2022
IFU 11/1/2022
2022-11-01 14:27:33 -04:00
Michael Melesse
15886b5ffc skip segfault 2022-11-01 17:52:18 +00:00
Michael Melesse
d5830b4b6a Merge branch 'master' into IFU_11_1_2022 2022-11-01 17:29:10 +00:00
Michael Melesse
bba1579485 remove scripts 2022-11-01 17:24:35 +00:00
rsanthanam-amd
cc6b5180c7 Merge pull request #19 from ROCmSoftwarePlatform/unskip_test_reduce
reduce the skips for test_reduce functions
2022-11-01 11:05:18 -05:00
Michael Melesse
dfad6bdf36 reduce the skips for test_reduce functions 2022-11-01 15:00:12 +00:00
rsanthanam-amd
f3bcbcfde6 Merge pull request #18 from ROCmSoftwarePlatform/fix_test_dot
Fix 6/7 test dot
2022-11-01 09:34:37 -05:00
Michael Melesse
7ec29a7453 revert scripts 2022-11-01 14:22:33 +00:00
Michael Melesse
4fb9d4904e fix 6/7 dot tests 2022-11-01 14:18:06 +00:00
Michael Melesse
4f3e2d6ed7 Merge branch 'rocm52_fixes_IFU' into fix_test_dot 2022-10-31 19:24:45 +00:00
rsanthanam-amd
fecc7ce248 Fix for test_bitwise subtests for ROCm. (#16)
The issue was that the kernel names were colliding with each other in
the cache.  Since the kernel names were based on the date and time, the
kernels were getting compiled so fast that a subsequent kernel would end
up with the same name as the previous one and would therefore overwrite
it in the cache.

It seems the suite runs the same test multiple times, and the subsequent
runs would end up using the wrong kernel because of these collisions.

It is fixed by appending a randomly generated alphanumeric string to
keep the kernel names unique.
2022-10-31 15:24:08 -04:00
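Since the fix appends a random alphanumeric suffix to the time-based kernel name, a small illustrative sketch (names assumed, not the repo's code) might look like:

#include <cstddef>
#include <random>
#include <string>

// Illustrative: produce a short random alphanumeric tag so that two kernels
// compiled within the same timestamp still get distinct cache names.
inline std::string random_alnum(std::size_t len = 10) {
  static const char alphabet[] =
      "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
  static std::mt19937 rng{std::random_device{}()};
  std::uniform_int_distribution<std::size_t> pick(0, sizeof(alphabet) - 2);
  std::string s;
  for (std::size_t i = 0; i < len; ++i)
    s += alphabet[pick(rng)];
  return s;
}

// e.g. cache_name = kernel_name + "_" + timestamp + "_" + random_alnum();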
Michael Melesse
277b712284 save changes 2022-10-31 19:11:58 +00:00
Michael Melesse
d024f0cfb8 update test_dot to use float 32 2022-10-31 18:58:10 +00:00
Michael Melesse
1811791665 add failures in report 2022-10-31 18:39:58 +00:00
Michael Melesse
9b3f2487b5 fix minor bug 2022-10-31 18:33:47 +00:00
rsanthanam-amd
14730a2352 Merge pull request #15 from ROCmSoftwarePlatform/bfloat_enable
unskip most bfloat tests
2022-10-31 13:10:30 -05:00
Michael Melesse
15683986cd unskip most bfloat tests 2022-10-31 18:04:54 +00:00
rsanthanam-amd
48fcd8c987 Merge pull request #14 from ROCmSoftwarePlatform/fix_vectorization
fix test_vectorization and test_load_cache_modifier
2022-10-28 16:12:57 -05:00
Michael Melesse
8d9572bc63 add similar fixes two addition tests 2022-10-28 20:34:58 +00:00
Michael Melesse
ffb30cdc52 skip ptx assert 2022-10-28 20:23:11 +00:00
Michael Melesse
7fce2bc5f1 add print_llvm_module 2022-10-28 20:07:35 +00:00
rsanthanam-amd
531ef18cb6 Fix for binop % (mod) unit test failures. (#13)
If either data type is fp, then fmod should be used for the
reference computation.
2022-10-28 15:06:17 -04:00
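A tiny sketch of the rule described above (an assumed helper, not the test's actual code): the host-side reference switches from integer '%' to fmod when either operand is floating point.

#include <cmath>
#include <type_traits>

// Illustrative reference computation for the '%' binop: use std::fmod when
// either operand is floating point, integer modulo otherwise.
template <typename A, typename B>
auto reference_mod(A a, B b) {
  if constexpr (std::is_floating_point_v<A> || std::is_floating_point_v<B>)
    return std::fmod(static_cast<double>(a), static_cast<double>(b));
  else
    return a % b;
}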
Michael Melesse
5f0d90db7e tab prints 2022-10-28 19:05:42 +00:00
Michael Melesse
03ae41b310 add print helper 2022-10-28 17:55:28 +00:00
Michael Melesse
bd61338b31 update scripts 2022-10-28 17:48:26 +00:00
Michael Melesse
6e50f8b2c0 print irs 2022-10-28 17:46:52 +00:00
Michael Melesse
aa556d4f1b update script 2022-10-26 21:51:15 +00:00
Michael Melesse
61e88efb23 ignore logs 2022-10-26 21:42:41 +00:00
Michael Melesse
ed9638801a fix for test_cast 2022-10-26 21:34:58 +00:00
Michael Melesse
8ecab462f6 skip segfaults on ROCM 2022-10-26 20:46:47 +00:00
Michael Melesse
648e4cfe89 skip test_atomic_rmw on rocm 2022-10-26 18:22:23 +00:00
Michael Melesse
abe0d3e1b1 cast to amd device when as_nvidia shows up 2022-10-26 18:12:18 +00:00
Michael Melesse
4464dfcc18 save scripts 2022-10-26 17:42:58 +00:00
Michael Melesse
0cae0168ec fix bfloat failure 2022-10-26 17:40:28 +00:00
Michael Melesse
88d57ef9c9 add cache print 2022-10-26 17:19:30 +00:00
Michael Melesse
39381d99f8 send amdgcn to cache 2022-10-26 17:18:33 +00:00
Michael Melesse
df925f7187 add cache print script 2022-10-25 20:48:36 +00:00
Michael Melesse
e84297ca79 print cache 2022-10-25 20:44:42 +00:00
Michael Melesse
61c85c18b2 try to load binary 2022-10-25 20:29:43 +00:00
Michael Melesse
da5c24ffcb just clean cache 2022-10-25 20:27:13 +00:00
Michael Melesse
09302f0106 fix linking bug 2022-10-25 18:31:10 +00:00
Michael Melesse
9184b5cf65 add prints 2022-10-24 18:28:28 +00:00
Michael Melesse
8da4323514 write hipmodule bytes 2022-10-24 17:58:25 +00:00
Michael Melesse
eb89e9bdd9 fix generator.cc: generator::visit_function: segfault 2022-10-24 17:41:20 +00:00
Michael Melesse
56a06f7a06 add debug steps 2022-10-21 20:17:30 +00:00
Michael Melesse
6a31c43774 update batcktrace 2022-10-21 19:56:19 +00:00
Michael Melesse
8785793445 fix typo 2022-10-21 17:58:38 +00:00
Michael Melesse
d022f5cf2c add compiling back to gcn 2022-10-21 17:54:31 +00:00
Michael Melesse
4624fd4e1d save compiler 2022-10-19 20:39:32 +00:00
Michael Melesse
41144f927f fix hip launch 2022-10-17 20:41:28 +00:00
Michael Melesse
4d6d4c9431 hip src 2022-10-17 20:18:44 +00:00
Michael Melesse
32dbc08c05 fix llvm build errors 2022-10-17 18:29:15 +00:00
Michael Melesse
4f21501def add fixes 2022-10-17 18:21:14 +00:00
Michael Melesse
5c548fb57e Merge branch 'master' into rcom52_fixes 2022-10-17 17:53:48 +00:00
Michael Melesse
fa4d0fd1ef add scripts 2022-10-17 17:28:48 +00:00
Daniil Fukalov
406d03bfaf Improve ROCm support. (#780)
- updates to support ROCm 5.2
- workarounds in tests where NV tools were used unconditionally
- implemented `get_num_blocks()` and `add_memfence()` for AMD GPU
- backported from history some atomics
- added bf16 support
- minor warnings cleanup
- added dockerfile to run on a ROCm enabled machine

Co-authored-by: B1tway <andrew.shukshov@gmail.com>
Co-authored-by: Andrey Shukshov <36711069+B1tway@users.noreply.github.com>
2022-10-14 11:33:42 -07:00
Michael Melesse
94d5c2e8b5 [ROCM] enable matmul(dot) and others (#391) 2021-12-13 12:28:15 -08:00
35 changed files with 1532 additions and 661 deletions

3
.gitignore vendored
View File

@@ -9,4 +9,5 @@ python/triton/_C/libtriton.pyd
python/triton/_C/libtriton.so
.vscode
.vs
.vs
log_*

View File

@@ -3,6 +3,13 @@ include(ExternalProject)
set(CMAKE_CXX_STANDARD 17)
if(NOT TRITON_LLVM_BUILD_DIR)
set(TRITON_LLVM_BUILD_DIR ${CMAKE_BINARY_DIR})
endif()
set(TRITON_USE_ROCM "$ENV{TRITON_USE_ROCM}")
set(TRITON_ROCM_DEBUG "$ENV{TRITON_ROCM_DEBUG}")
project(triton)
include(CTest)
if(NOT WIN32)
@@ -35,7 +42,11 @@ if(WIN32)
add_subdirectory(deps/dlfcn-win32/src ${CMAKE_BINARY_DIR}/dlfcn-win32)
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17")
if (TRITON_USE_ROCM)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17 -Wno-unused-result -Wno-attributes")
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17")
endif()
##########
@@ -135,6 +146,13 @@ if(BUILD_PYTHON_MODULE)
endif()
include_directories("." ${PYTHON_SRC_PATH} ${PYTHON_INCLUDE_DIRS} ${CUTLASS_INCLUDE_DIR})
link_directories(${PYTHON_LINK_DIRS} ${CUTLASS_LIBRARY_DIR})
if (TRITON_USE_ROCM)
add_definitions(-DUSE_ROCM)
endif()
if (TRITON_ROCM_DEBUG)
add_definitions(-DDEBUG_ROCM)
endif()
set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/triton.cc ${PYTHON_SRC_PATH}/superblock.cc ${CUTLASS_SRC})
endif()

163
include/print_helper.h Executable file
View File

@@ -0,0 +1,163 @@
#pragma once
#ifndef _PRINT_IR_H_
#define _PRINT_IR_H_
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IR/IRPrintingPasses.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IR/Module.h"
#include <chrono>
#include <ctime>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "triton/ir/module.h"
#include "triton/ir/print.h"
#define PRINT_CURRENT_FUNCTION() std::cout << __FILE__ << ":" << __LINE__ << ":" << __FUNCTION__ << std::endl;
static int print_count = 0;
inline std::string return_current_time_and_date()
{
auto now = std::chrono::system_clock::now();
auto in_time_t = std::chrono::system_clock::to_time_t(now);
std::stringstream ss;
ss << std::put_time(std::localtime(&in_time_t), "%Y-%m-%d--%I-%M-%S");
return ss.str();
}
template <typename T>
inline void print_vector(std::vector<T> &vec, std::string name = "")
{
std::cout << name << ": ";
for (auto v : vec)
{
std::cout << v << ", ";
}
std::cout << '\b';
std::cout << std::endl;
}
// dump llvm ir to tmp file
inline std::string print_llvm_module(llvm::Module *llvm_module, bool print_to_cout = true)
{
std::cout << "\t" << "print_llvm_module" << std::endl;
// get module as a string
std::error_code ec;
std::string mod_string;
std::unique_ptr<llvm::raw_string_ostream> ir_ss(
new llvm::raw_string_ostream(mod_string));
llvm_module->print(*ir_ss, nullptr);
// print module
if (print_to_cout)
{
if (!mod_string.empty())
std::cout << "\t" << mod_string << std::endl;
else
std::cout << "\t" << llvm_module->getModuleIdentifier() << ": "
<< "is empty" << std::endl;
}
return mod_string;
}
// dump llvm ir to tmp file
inline void write_llvm_ir(llvm::Module *llvm_module, std::string filename = "", bool tracked = false)
{
// get module string
std::string module_string = print_llvm_module(llvm_module, false);
// get file name and path
if (filename.empty())
filename = llvm_module->getModuleIdentifier();
std::string count_str = "";
if (tracked)
{
count_str = "_" + std::to_string(print_count);
}
std::string ir_path = std::string("/tmp/") + filename + count_str + std::string(".ll");
// write file
std::ofstream output_file(ir_path);
output_file << module_string;
output_file.close();
// increment counter
if (tracked)
{
print_count += 1;
}
}
inline void print_triton_ir(triton::ir::module &ir_ref, std::string name)
{
std::ofstream ir_out(std::string("/tmp/") + name + std::string("_") + return_current_time_and_date() + std::string(".ttir"));
ir_out.flush();
triton::ir::print(ir_ref, ir_out);
ir_out.close();
}
inline void print_triton_ir(std::string ir_ref, std::string name)
{
std::ofstream ir_out(std::string("/tmp/") + name + std::string("_") + return_current_time_and_date() + std::string(".ttir"));
ir_out.flush();
ir_out << ir_ref << std::endl;
ir_out.close();
}
inline std::string get_llvm_value_as_str(llvm::Value *llvm_value)
{
std::string value_str;
llvm::raw_string_ostream rso(value_str);
llvm_value->print(rso);
return rso.str();
}
inline void print_llvm_value(llvm::Value *llvm_value, std::string name = "")
{
if (llvm_value)
std::cout << "\t" << name << ": " << get_llvm_value_as_str(llvm_value) << std::endl;
else
std::cout << "\t" << name << ": "
<< "is nullptr" << std::endl;
}
inline void print_llvm_type(llvm::Type *llvm_type, std::string name = "")
{
std::string type_str;
llvm::raw_string_ostream rso(type_str);
llvm_type->print(rso);
std::cout << name << " type: " << rso.str() << std::endl;
}
inline void print_llvm_value_type(llvm::Value *llvm_value, std::string name = "")
{
print_llvm_type(llvm_value->getType(), name);
}
inline void write_ptx(std::string ptx_str)
{
std::ofstream file("/tmp/kernel.ptx");
file << ptx_str;
}
#endif
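A short usage sketch for the helpers above (the call site and pass name are illustrative):

#include "print_helper.h"

// Illustrative call site: echo the IR to stdout and also dump a tracked copy
// to /tmp/after_layout_<n>.ll for later diffing.
void debug_dump(llvm::Module *llvm_module) {
  print_llvm_module(llvm_module);
  write_llvm_ir(llvm_module, "after_layout", /*tracked=*/true);
}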

View File

@@ -36,6 +36,7 @@ namespace triton{
namespace codegen{
class nvidia_cu_target;
class amd_cl_target;
class target {
public:
@@ -49,7 +50,12 @@ public:
virtual Value* get_block_id(Module *module, Builder& builder, unsigned ax) = 0;
virtual Value* get_num_blocks(Module *module, Builder& builder, unsigned ax) = 0;
virtual unsigned guaranteed_alignment() = 0;
#ifdef USE_ROCM
amd_cl_target* as_nvidia();
amd_cl_target* as_amd();
#else
nvidia_cu_target* as_nvidia();
#endif
bool is_gpu() const;
private:
@@ -67,6 +73,7 @@ public:
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
unsigned guaranteed_alignment() { return 16; }
int sm() { return 0; } // treat as if old CUDA device
};
class nvidia_cu_target: public target {

View File

@@ -11,7 +11,6 @@
#include "triton/external/CUDA/nvml.h"
//// HIP backend
//#define __HIP_PLATFORM_AMD__
#include "triton/external/hip.h"
//Exceptions
@@ -183,7 +182,8 @@ public:
static hipError_t hipEventElapsedTime(float *pMilliseconds, hipEvent_t hStart, hipEvent_t hEnd);
static hipError_t hipEventRecord(hipEvent_t hEvent, hipStream_t hStream);
static hipError_t hipEventDestroy(hipEvent_t hEvent);
// error handling
static hipError_t hipGetLastError(void);
private:
@@ -309,6 +309,8 @@ private:
static void* hipEventElapsedTime_;
static void* hipEventRecord_;
static void* hipEventDestroy_;
// error handling
static void* hipGetLastError_;
};
}
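The paired hipGetLastError declaration and hipGetLastError_ pointer follow the dispatch table's lazy-binding pattern; a hedged sketch of how such a wrapper is typically resolved (the library name and error handling are assumptions, not the repo's code):

#include <dlfcn.h>
#include <stdexcept>

// Illustrative lazy binding: resolve the HIP symbol once, cache it in the
// static pointer, and forward subsequent calls through it.
using hipGetLastError_fn = int (*)(void);   // assumed: hipError_t hipGetLastError(void)
static void *hipGetLastError_ = nullptr;

inline int dispatch_hipGetLastError() {
  if (!hipGetLastError_) {
    void *lib = dlopen("libamdhip64.so", RTLD_LAZY | RTLD_GLOBAL);
    if (!lib)
      throw std::runtime_error("could not load the HIP runtime");
    hipGetLastError_ = dlsym(lib, "hipGetLastError");
    if (!hipGetLastError_)
      throw std::runtime_error("hipGetLastError not found");
  }
  return reinterpret_cast<hipGetLastError_fn>(hipGetLastError_)();
}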

View File

@@ -13,8 +13,12 @@ std::string path_to_ptxas(int& version);
std::string llir_to_ptx(llvm::Module* module, int cc, int version);
std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas_path, int cc);
CUmodule ptx_to_cumodule(const std::string& ptx, int cc);
std::string llir_to_amdgpu(llvm::Module* module, const std::string& proc);
std::tuple<std::string, std::string> llir_to_amdgcn(llvm::Module* module, const std::string& proc);
hipModule_t amdgpu_to_hipmodule(const std::string& path);
}
}
#define STRINGIFY_HELPER(X) #X
#define STRINGIFY(X) STRINGIFY_HELPER(X)
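A hedged usage sketch of the AMD path declared above; the tuple layout (GCN assembly plus a path to the compiled code object) and the glue function are assumptions, and only the two triton::driver entry points come from the header:

#include <string>
#include <tuple>

namespace llvm { class Module; }            // provided by the LLVM headers in the real build
typedef struct ihipModule_t *hipModule_t;   // matches the HIP typedef used elsewhere

namespace triton { namespace driver {
std::tuple<std::string, std::string> llir_to_amdgcn(llvm::Module *module, const std::string &proc);
hipModule_t amdgpu_to_hipmodule(const std::string &path);
} }

// Illustrative flow: lower LLVM IR to GCN assembly plus a code-object file,
// then load that code object as a HIP module ready for kernel launches.
inline hipModule_t compile_for_amd(llvm::Module *mod, const std::string &gfx_arch) {
  auto [amdgcn_asm, hsaco_path] = triton::driver::llir_to_amdgcn(mod, gfx_arch);
  (void)amdgcn_asm;  // the textual GCN is handy for debugging dumps
  return triton::driver::amdgpu_to_hipmodule(hsaco_path);
}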

View File

@@ -818,7 +818,7 @@ typedef enum CUcomputemode_enum {
* Memory advise values
*/
typedef enum CUmem_advise_enum {
CU_MEM_ADVISE_SET_READ_MOSTLY = 1, /**< Data will mostly be read and only occasionally be written to */
CU_MEM_ADVISE_SET_READ_MOSTLY = 1, /**< Data will mostly be read and only occassionally be written to */
CU_MEM_ADVISE_UNSET_READ_MOSTLY = 2, /**< Undo the effect of ::CU_MEM_ADVISE_SET_READ_MOSTLY */
CU_MEM_ADVISE_SET_PREFERRED_LOCATION = 3, /**< Set the preferred location for the data as the specified device */
CU_MEM_ADVISE_UNSET_PREFERRED_LOCATION = 4, /**< Clear the preferred location for the data */
@@ -827,7 +827,7 @@ typedef enum CUmem_advise_enum {
} CUmem_advise;
typedef enum CUmem_range_attribute_enum {
CU_MEM_RANGE_ATTRIBUTE_READ_MOSTLY = 1, /**< Whether the range will mostly be read and only occasionally be written to */
CU_MEM_RANGE_ATTRIBUTE_READ_MOSTLY = 1, /**< Whether the range will mostly be read and only occassionally be written to */
CU_MEM_RANGE_ATTRIBUTE_PREFERRED_LOCATION = 2, /**< The preferred location of the range */
CU_MEM_RANGE_ATTRIBUTE_ACCESSED_BY = 3, /**< Memory range has ::CU_MEM_ADVISE_SET_ACCESSED_BY set for specified device */
CU_MEM_RANGE_ATTRIBUTE_LAST_PREFETCH_LOCATION = 4 /**< The last location to which the range was prefetched */
@@ -849,7 +849,7 @@ typedef enum CUjit_option_enum
* IN: Specifies minimum number of threads per block to target compilation
* for\n
* OUT: Returns the number of threads the compiler actually targeted.
* This restricts the resource utilization of the compiler (e.g. max
* This restricts the resource utilization fo the compiler (e.g. max
* registers) such that a block with the given number of threads should be
* able to launch based on register limitations. Note, this option does not
* currently take into account any other resource limitations, such as
@@ -974,10 +974,10 @@ typedef enum CUjit_option_enum
CU_JIT_FAST_COMPILE,
/**
* Array of device symbol names that will be relocated to the corresponding
* Array of device symbol names that will be relocated to the corresponing
* host addresses stored in ::CU_JIT_GLOBAL_SYMBOL_ADDRESSES.\n
* Must contain ::CU_JIT_GLOBAL_SYMBOL_COUNT entries.\n
* When loading a device module, driver will relocate all encountered
* When loding a device module, driver will relocate all encountered
* unresolved symbols to the host addresses.\n
* It is only allowed to register symbols that correspond to unresolved
* global variables.\n
@@ -1194,7 +1194,7 @@ typedef enum CUlimit_enum {
* Resource types
*/
typedef enum CUresourcetype_enum {
CU_RESOURCE_TYPE_ARRAY = 0x00, /**< Array resource */
CU_RESOURCE_TYPE_ARRAY = 0x00, /**< Array resoure */
CU_RESOURCE_TYPE_MIPMAPPED_ARRAY = 0x01, /**< Mipmapped array resource */
CU_RESOURCE_TYPE_LINEAR = 0x02, /**< Linear resource */
CU_RESOURCE_TYPE_PITCH2D = 0x03 /**< Pitch 2D resource */
@@ -2914,9 +2914,9 @@ typedef struct CUmemAllocationProp_st {
CUmemLocation location;
/**
* Windows-specific POBJECT_ATTRIBUTES required when
* ::CU_MEM_HANDLE_TYPE_WIN32 is specified. This object attributes structure
* ::CU_MEM_HANDLE_TYPE_WIN32 is specified. This object atributes structure
* includes security attributes that define
* the scope of which exported allocations may be transferred to other
* the scope of which exported allocations may be tranferred to other
* processes. In all other cases, this field is required to be zero.
*/
void *win32HandleMetaData;
@@ -3036,7 +3036,7 @@ typedef struct CUmemPoolProps_st {
/**
* Windows-specific LPSECURITYATTRIBUTES required when
* ::CU_MEM_HANDLE_TYPE_WIN32 is specified. This security attribute defines
* the scope of which exported allocations may be transferred to other
* the scope of which exported allocations may be tranferred to other
* processes. In all other cases, this field is required to be zero.
*/
void *win32SecurityAttributes;
@@ -3519,7 +3519,7 @@ CUresult CUDAAPI cuDeviceGet(CUdevice *device, int ordinal);
CUresult CUDAAPI cuDeviceGetCount(int *count);
/**
* \brief Returns an identifier string for the device
* \brief Returns an identifer string for the device
*
* Returns an ASCII string identifying the device \p dev in the NULL-terminated
* string pointed to by \p name. \p len specifies the maximum length of the
@@ -3556,7 +3556,7 @@ CUresult CUDAAPI cuDeviceGetName(char *name, int len, CUdevice dev);
* Note there is a later version of this API, ::cuDeviceGetUuid_v2. It will
* supplant this version in 12.0, which is retained for minor version compatibility.
*
* Returns 16-octets identifying the device \p dev in the structure
* Returns 16-octets identifing the device \p dev in the structure
* pointed by the \p uuid.
*
* \param uuid - Returned UUID
@@ -3586,7 +3586,7 @@ CUresult CUDAAPI cuDeviceGetUuid(CUuuid *uuid, CUdevice dev);
/**
* \brief Return an UUID for the device (11.4+)
*
* Returns 16-octets identifying the device \p dev in the structure
* Returns 16-octets identifing the device \p dev in the structure
* pointed by the \p uuid. If the device is in MIG mode, returns its
* MIG UUID which uniquely identifies the subscribed MIG compute instance.
*
@@ -3867,7 +3867,7 @@ CUresult CUDAAPI cuDeviceGetTexture1DLinearMaxWidth(size_t *maxWidthInElements,
* supports native atomic operations.
* - ::CU_DEVICE_ATTRIBUTE_SINGLE_TO_DOUBLE_PRECISION_PERF_RATIO: Ratio of single precision performance
* (in floating-point operations per second) to double precision performance.
* - ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS: Device supports coherently accessing
* - ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS: Device suppports coherently accessing
* pageable memory without calling cudaHostRegister on it.
* - ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS: Device can coherently access managed memory
* concurrently with the CPU.
@@ -3875,7 +3875,7 @@ CUresult CUDAAPI cuDeviceGetTexture1DLinearMaxWidth(size_t *maxWidthInElements,
* - ::CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM: Device can access host registered
* memory at the same virtual address as the CPU.
* - ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN: The maximum per block shared memory size
* supported on this device. This is the maximum value that can be opted into when using the cuFuncSetAttribute() call.
* suported on this device. This is the maximum value that can be opted into when using the cuFuncSetAttribute() call.
* For more details see ::CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES
* - ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES: Device accesses pageable memory via the host's
* page tables.
@@ -4132,7 +4132,7 @@ __CUDA_DEPRECATED CUresult CUDAAPI cuDeviceGetProperties(CUdevprop *prop, CUdevi
*
* \deprecated
*
* This function was deprecated as of CUDA 5.0 and its functionality superseded
* This function was deprecated as of CUDA 5.0 and its functionality superceded
* by ::cuDeviceGetAttribute().
*
* Returns in \p *major and \p *minor the major and minor revision numbers that
@@ -4962,10 +4962,10 @@ CUresult CUDAAPI cuCtxSynchronize(void);
* returned.
*
* - ::CU_LIMIT_MAX_L2_FETCH_GRANULARITY controls the L2 cache fetch granularity.
* Values can range from 0B to 128B. This is purely a performance hint and
* Values can range from 0B to 128B. This is purely a performence hint and
* it can be ignored or clamped depending on the platform.
*
* - ::CU_LIMIT_PERSISTING_L2_CACHE_SIZE controls size in bytes available for
* - ::CU_LIMIT_PERSISTING_L2_CACHE_SIZE controls size in bytes availabe for
* persisting L2 cache. This is purely a performance hint and it can be
* ignored or clamped depending on the platform.
*
@@ -6398,7 +6398,7 @@ CUresult CUDAAPI cuMemHostGetFlags(unsigned int *pFlags, void *p);
* ::cuStreamAttachMemAsync will be required to enable access on such devices.
*
* If the association is later changed via ::cuStreamAttachMemAsync to
* a single stream, the default association as specified during ::cuMemAllocManaged
* a single stream, the default association as specifed during ::cuMemAllocManaged
* is restored when that stream is destroyed. For __managed__ variables, the
* default association is always ::CU_MEM_ATTACH_GLOBAL. Note that destroying a
* stream is an asynchronous operation, and as a result, the change to default
@@ -9616,13 +9616,13 @@ CUresult CUDAAPI cuMemAddressFree(CUdeviceptr ptr, size_t size);
* \brief Create a CUDA memory handle representing a memory allocation of a given size described by the given properties
*
* This creates a memory allocation on the target device specified through the
* \p prop structure. The created allocation will not have any device or host
* \p prop strcuture. The created allocation will not have any device or host
* mappings. The generic memory \p handle for the allocation can be
* mapped to the address space of calling process via ::cuMemMap. This handle
* cannot be transmitted directly to other processes (see
* ::cuMemExportToShareableHandle). On Windows, the caller must also pass
* an LPSECURITYATTRIBUTE in \p prop to be associated with this handle which
* limits or allows access to this handle for a recipient process (see
* limits or allows access to this handle for a recepient process (see
* ::CUmemAllocationProp::win32HandleMetaData for more). The \p size of this
* allocation must be a multiple of the the value given via
* ::cuMemGetAllocationGranularity with the ::CU_MEM_ALLOC_GRANULARITY_MINIMUM
@@ -9660,7 +9660,7 @@ CUresult CUDAAPI cuMemCreate(CUmemGenericAllocationHandle *handle, size_t size,
* are unmapped and when all outstanding references to the handle (including it's
* shareable counterparts) are also released. The generic memory handle can be
* freed when there are still outstanding mappings made with this handle. Each
* time a recipient process imports a shareable handle, it needs to pair it with
* time a recepient process imports a shareable handle, it needs to pair it with
* ::cuMemRelease for the handle to be freed. If \p handle is not a valid handle
* the behavior is undefined.
*
@@ -10975,7 +10975,7 @@ CUresult CUDAAPI cuMemAdvise(CUdeviceptr devPtr, size_t count, CUmem_advise advi
* a GPU id or CU_DEVICE_CPU depending on whether the last location for prefetch was a GPU or the CPU
* respectively. If any page in the memory range was never explicitly prefetched or if all pages were not
* prefetched to the same location, CU_DEVICE_INVALID will be returned. Note that this simply returns the
* last location that the application requested to prefetch the memory range to. It gives no indication as to
* last location that the applicaton requested to prefetch the memory range to. It gives no indication as to
* whether the prefetch operation to that location has completed or even begun.
*
* \param data - A pointers to a memory location where the result
@@ -13561,7 +13561,7 @@ CUresult CUDAAPI cuLaunchCooperativeKernel(CUfunction f,
* All kernels launched must be identical with respect to the compiled code. Note that
* any __device__, __constant__ or __managed__ variables present in the module that owns
* the kernel launched on each device, are independently instantiated on every device.
* It is the application's responsibility to ensure these variables are initialized and
* It is the application's responsiblity to ensure these variables are initialized and
* used appropriately.
*
* The size of the grids as specified in blocks, the size of the blocks themselves

View File

@@ -328,7 +328,7 @@ typedef enum nvmlGpuLevel_enum
typedef enum nvmlGpuP2PStatus_enum
{
NVML_P2P_STATUS_OK = 0,
NVML_P2P_STATUS_CHIPSET_NOT_SUPPORTED,
NVML_P2P_STATUS_CHIPSET_NOT_SUPPORED,
NVML_P2P_STATUS_GPU_NOT_SUPPORTED,
NVML_P2P_STATUS_IOH_TOPOLOGY_NOT_SUPPORTED,
NVML_P2P_STATUS_DISABLED_BY_REGKEY,
@@ -736,7 +736,7 @@ typedef enum nvmlReturn_enum
NVML_ERROR_IN_USE = 19, //!< An operation cannot be performed because the GPU is currently in use
NVML_ERROR_MEMORY = 20, //!< Insufficient memory
NVML_ERROR_NO_DATA = 21, //!<No data
NVML_ERROR_VGPU_ECC_NOT_SUPPORTED = 22, //!< The requested vgpu operation is not available on target device, because ECC is enabled
NVML_ERROR_VGPU_ECC_NOT_SUPPORTED = 22, //!< The requested vgpu operation is not available on target device, becasue ECC is enabled
NVML_ERROR_UNKNOWN = 999 //!< An internal driver error occurred
} nvmlReturn_t;
@@ -1463,7 +1463,7 @@ typedef struct nvmlEncoderSessionInfo_st
*/
typedef enum nvmlFBCSessionType_enum
{
NVML_FBC_SESSION_TYPE_UNKNOWN = 0, //!< Unknown
NVML_FBC_SESSION_TYPE_UNKNOWN = 0, //!< Unknwon
NVML_FBC_SESSION_TYPE_TOSYS, //!< ToSys
NVML_FBC_SESSION_TYPE_CUDA, //!< Cuda
NVML_FBC_SESSION_TYPE_VID, //!< Vid
@@ -3678,10 +3678,10 @@ nvmlReturn_t DECLDIR nvmlDeviceGetEncoderStats (nvmlDevice_t device, unsigned in
* Retrieves information about active encoder sessions on a target device.
*
* An array of active encoder sessions is returned in the caller-supplied buffer pointed at by \a sessionInfos. The
* array element count is passed in \a sessionCount, and \a sessionCount is used to return the number of sessions
* array elememt count is passed in \a sessionCount, and \a sessionCount is used to return the number of sessions
* written to the buffer.
*
* If the supplied buffer is not large enough to accommodate the active session array, the function returns
* If the supplied buffer is not large enough to accomodate the active session array, the function returns
* NVML_ERROR_INSUFFICIENT_SIZE, with the element count of nvmlEncoderSessionInfo_t array required in \a sessionCount.
* To query the number of active encoder sessions, call this function with *sessionCount = 0. The code will return
* NVML_SUCCESS with number of active encoder sessions updated in *sessionCount.
@@ -3727,7 +3727,7 @@ nvmlReturn_t DECLDIR nvmlDeviceGetDecoderUtilization(nvmlDevice_t device, unsign
* For Maxwell &tm; or newer fully supported devices.
*
* @param device The identifier of the target device
* @param fbcStats Reference to nvmlFBCStats_t structure containing NvFBC stats
* @param fbcStats Reference to nvmlFBCStats_t structure contianing NvFBC stats
*
* @return
* - \ref NVML_SUCCESS if \a fbcStats is fetched
@@ -3742,10 +3742,10 @@ nvmlReturn_t DECLDIR nvmlDeviceGetFBCStats(nvmlDevice_t device, nvmlFBCStats_t *
* Retrieves information about active frame buffer capture sessions on a target device.
*
* An array of active encoder sessions is returned in the caller-supplied buffer pointed at by \a sessionInfo. The
* array element count is passed in \a sessionCount, and \a sessionCount is used to return the number of sessions
* array elememt count is passed in \a sessionCount, and \a sessionCount is used to return the number of sessions
* written to the buffer.
*
* If the supplied buffer is not large enough to accommodate the active session array, the function returns
* If the supplied buffer is not large enough to accomodate the active session array, the function returns
* NVML_ERROR_INSUFFICIENT_SIZE, with the element count of nvmlFBCSessionInfo_t array required in \a sessionCount.
* To query the number of active FBC sessions, call this function with *sessionCount = 0. The code will return
* NVML_SUCCESS with number of active FBC sessions updated in *sessionCount.
@@ -4208,7 +4208,7 @@ nvmlReturn_t DECLDIR nvmlDeviceGetRetiredPages(nvmlDevice_t device, nvmlPageReti
* The address information provided from this API is the hardware address of the page that was retired. Note
* that this does not match the virtual address used in CUDA, but will match the address information in XID 63
*
* \note nvmlDeviceGetRetiredPages_v2 adds an additional timestamps parameter to return the time of each page's
* \note nvmlDeviceGetRetiredPages_v2 adds an additional timestamps paramter to return the time of each page's
* retirement.
*
* For Kepler &tm; or newer fully supported devices.
@@ -4476,7 +4476,7 @@ nvmlReturn_t DECLDIR nvmlDeviceSetDriverModel(nvmlDevice_t device, nvmlDriverMod
* Set clocks that device will lock to.
*
* Sets the clocks that the device will be running at to the value in the range of minGpuClockMHz to maxGpuClockMHz.
* Setting this will supersede application clock values and take effect regardless if a cuda app is running.
* Setting this will supercede application clock values and take effect regardless if a cuda app is running.
* See /ref nvmlDeviceSetApplicationsClocks
*
* Can be used as a setting to request constant performance.
@@ -5297,7 +5297,7 @@ nvmlReturn_t DECLDIR nvmlDeviceSetVirtualizationMode(nvmlDevice_t device, nvmlGp
* pointed at by \a vgpuTypeIds. The element count of nvmlVgpuTypeId_t array is passed in \a vgpuCount, and \a vgpuCount
* is used to return the number of vGPU types written to the buffer.
*
* If the supplied buffer is not large enough to accommodate the vGPU type array, the function returns
* If the supplied buffer is not large enough to accomodate the vGPU type array, the function returns
* NVML_ERROR_INSUFFICIENT_SIZE, with the element count of nvmlVgpuTypeId_t array required in \a vgpuCount.
* To query the number of vGPU types supported for the GPU, call this function with *vgpuCount = 0.
* The code will return NVML_ERROR_INSUFFICIENT_SIZE, or NVML_SUCCESS if no vGPU types are supported.
@@ -5327,9 +5327,9 @@ nvmlReturn_t DECLDIR nvmlDeviceGetSupportedVgpus(nvmlDevice_t device, unsigned i
* can concurrently run on a device. For example, if only one vGPU type is allowed at a time on a device, then the creatable
* list will be restricted to whatever vGPU type is already running on the device.
*
* If the supplied buffer is not large enough to accommodate the vGPU type array, the function returns
* If the supplied buffer is not large enough to accomodate the vGPU type array, the function returns
* NVML_ERROR_INSUFFICIENT_SIZE, with the element count of nvmlVgpuTypeId_t array required in \a vgpuCount.
* To query the number of vGPU types creatable for the GPU, call this function with *vgpuCount = 0.
* To query the number of vGPU types createable for the GPU, call this function with *vgpuCount = 0.
* The code will return NVML_ERROR_INSUFFICIENT_SIZE, or NVML_SUCCESS if no vGPU types are creatable.
*
* @param device The identifier of the target device
@@ -5392,7 +5392,7 @@ nvmlReturn_t DECLDIR nvmlVgpuTypeGetName(nvmlVgpuTypeId_t vgpuTypeId, char *vgpu
*
* @param vgpuTypeId Handle to vGPU type
* @param deviceID Device ID and vendor ID of the device contained in single 32 bit value
* @param subsystemID subsystem ID and subsystem vendor ID of the device contained in single 32 bit value
* @param subsystemID Subsytem ID and subsytem vendor ID of the device contained in single 32 bit value
*
* @return
* - \ref NVML_SUCCESS successful completion
@@ -5516,10 +5516,10 @@ nvmlReturn_t DECLDIR nvmlVgpuTypeGetMaxInstances(nvmlDevice_t device, nvmlVgpuTy
* Retrieve the active vGPU instances on a device.
*
* An array of active vGPU instances is returned in the caller-supplied buffer pointed at by \a vgpuInstances. The
* array element count is passed in \a vgpuCount, and \a vgpuCount is used to return the number of vGPU instances
* array elememt count is passed in \a vgpuCount, and \a vgpuCount is used to return the number of vGPU instances
* written to the buffer.
*
* If the supplied buffer is not large enough to accommodate the vGPU instance array, the function returns
* If the supplied buffer is not large enough to accomodate the vGPU instance array, the function returns
* NVML_ERROR_INSUFFICIENT_SIZE, with the element count of nvmlVgpuInstance_t array required in \a vgpuCount.
* To query the number of active vGPU instances, call this function with *vgpuCount = 0. The code will return
* NVML_ERROR_INSUFFICIENT_SIZE, or NVML_SUCCESS if no vGPU Types are supported.
@@ -5702,7 +5702,7 @@ nvmlReturn_t DECLDIR nvmlVgpuInstanceGetFrameRateLimit(nvmlVgpuInstance_t vgpuIn
* @param encoderCapacity Reference to an unsigned int for the encoder capacity
*
* @return
* - \ref NVML_SUCCESS if \a encoderCapacity has been retrieved
* - \ref NVML_SUCCESS if \a encoderCapacity has been retrived
* - \ref NVML_ERROR_UNINITIALIZED if the library has not been successfully initialized
* - \ref NVML_ERROR_INVALID_ARGUMENT if \a vgpuInstance is 0, or \a encoderQueryType is invalid
* - \ref NVML_ERROR_NOT_FOUND if \a vgpuInstance does not match a valid active vGPU instance on the system
@@ -5863,10 +5863,10 @@ nvmlReturn_t DECLDIR nvmlVgpuInstanceGetEncoderStats(nvmlVgpuInstance_t vgpuInst
* Retrieves information about all active encoder sessions on a vGPU Instance.
*
* An array of active encoder sessions is returned in the caller-supplied buffer pointed at by \a sessionInfo. The
* array element count is passed in \a sessionCount, and \a sessionCount is used to return the number of sessions
* array elememt count is passed in \a sessionCount, and \a sessionCount is used to return the number of sessions
* written to the buffer.
*
* If the supplied buffer is not large enough to accommodate the active session array, the function returns
* If the supplied buffer is not large enough to accomodate the active session array, the function returns
* NVML_ERROR_INSUFFICIENT_SIZE, with the element count of nvmlEncoderSessionInfo_t array required in \a sessionCount.
* To query the number of active encoder sessions, call this function with *sessionCount = 0. The code will return
* NVML_SUCCESS with number of active encoder sessions updated in *sessionCount.
@@ -5896,7 +5896,7 @@ nvmlReturn_t DECLDIR nvmlVgpuInstanceGetEncoderSessions(nvmlVgpuInstance_t vgpuI
* For Maxwell &tm; or newer fully supported devices.
*
* @param vgpuInstance Identifier of the target vGPU instance
* @param fbcStats Reference to nvmlFBCStats_t structure containing NvFBC stats
* @param fbcStats Reference to nvmlFBCStats_t structure contianing NvFBC stats
*
* @return
* - \ref NVML_SUCCESS if \a fbcStats is fetched
@@ -5914,7 +5914,7 @@ nvmlReturn_t DECLDIR nvmlVgpuInstanceGetFBCStats(nvmlVgpuInstance_t vgpuInstance
* array element count is passed in \a sessionCount, and \a sessionCount is used to return the number of sessions
* written to the buffer.
*
* If the supplied buffer is not large enough to accommodate the active session array, the function returns
* If the supplied buffer is not large enough to accomodate the active session array, the function returns
* NVML_ERROR_INSUFFICIENT_SIZE, with the element count of nvmlFBCSessionInfo_t array required in \a sessionCount.
* To query the number of active FBC sessions, call this function with *sessionCount = 0. The code will return
* NVML_SUCCESS with number of active FBC sessions updated in *sessionCount.
@@ -6094,7 +6094,7 @@ typedef struct nvmlVgpuPgpuMetadata_st
unsigned int version; //!< Current version of the structure
unsigned int revision; //!< Current revision of the structure
char hostDriverVersion[NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE]; //!< Host driver version
unsigned int pgpuVirtualizationCaps; //!< Pgpu virtualization capabilities bitfield
unsigned int pgpuVirtualizationCaps; //!< Pgpu virtualizaion capabilities bitfileld
unsigned int reserved[7]; //!< Reserved for internal use
unsigned int opaqueDataSize; //!< Size of opaque data field in bytes
char opaqueData[4]; //!< Opaque data
@@ -6191,7 +6191,7 @@ nvmlReturn_t DECLDIR nvmlDeviceGetVgpuMetadata(nvmlDevice_t device, nvmlVgpuPgpu
*
* The caller passes in a buffer via \a compatibilityInfo, into which a compatibility information structure is written. The
* structure defines the states in which the vGPU / VM may be booted on the physical GPU. If the vGPU / VM compatibility
* with the physical GPU is limited, a limit code indicates the factor limiting compatibility.
* with the physical GPU is limited, a limit code indicates the factor limiting compability.
* (see \ref nvmlVgpuPgpuCompatibilityLimitCode_t for details).
*
* Note: vGPU compatibility does not take into account dynamic capacity conditions that may limit a system's ability to

View File

@@ -950,7 +950,7 @@ namespace half_float
/// Convert half-precision floating point to integer.
/// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding
/// \tparam E `true` for round to even, `false` for round away from zero
/// \tparam T type to convert to (builtin integer type with at least 16 bits precision, excluding any implicit sign bits)
/// \tparam T type to convert to (buitlin integer type with at least 16 bits precision, excluding any implicit sign bits)
/// \param value binary representation of half-precision value
/// \return integral value
template<std::float_round_style R,bool E,typename T> T half2int_impl(uint16 value)
@@ -988,13 +988,13 @@ namespace half_float
/// Convert half-precision floating point to integer.
/// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding
/// \tparam T type to convert to (builtin integer type with at least 16 bits precision, excluding any implicit sign bits)
/// \tparam T type to convert to (buitlin integer type with at least 16 bits precision, excluding any implicit sign bits)
/// \param value binary representation of half-precision value
/// \return integral value
template<std::float_round_style R,typename T> T half2int(uint16 value) { return half2int_impl<R,HALF_ROUND_TIES_TO_EVEN,T>(value); }
/// Convert half-precision floating point to integer using round-to-nearest-away-from-zero.
/// \tparam T type to convert to (builtin integer type with at least 16 bits precision, excluding any implicit sign bits)
/// \tparam T type to convert to (buitlin integer type with at least 16 bits precision, excluding any implicit sign bits)
/// \param value binary representation of half-precision value
/// \return integral value
template<typename T> T half2int_up(uint16 value) { return half2int_impl<std::round_to_nearest,0,T>(value); }
@@ -1053,7 +1053,7 @@ namespace half_float
/// Half-precision floating point type.
/// This class implements an IEEE-conformant half-precision floating point type with the usual arithmetic operators and
/// conversions. It is implicitly convertible to single-precision floating point, which makes arithmetic expressions and
/// conversions. It is implicitly convertible to single-precision floating point, which makes artihmetic expressions and
/// functions with mixed-type operands to be of the most precise operand type. Additionally all arithmetic operations
/// (and many mathematical functions) are carried out in single-precision internally. All conversions from single- to
/// half-precision are done using the library's default rounding mode, but temporary results inside chained arithmetic
@@ -1062,7 +1062,7 @@ namespace half_float
/// According to the C++98/03 definition, the half type is not a POD type. But according to C++11's less strict and
/// extended definitions it is both a standard layout type and a trivially copyable type (even if not a POD type), which
/// means it can be standard-conformantly copied using raw binary copies. But in this context some more words about the
/// actual size of the type. Although the half is representing an IEEE 16-bit type, it does not necessarily have to be of
/// actual size of the type. Although the half is representing an IEEE 16-bit type, it does not neccessarily have to be of
/// exactly 16-bits size. But on any reasonable implementation the actual binary representation of this type will most
/// probably not ivolve any additional "magic" or padding beyond the simple binary representation of the underlying 16-bit
/// IEEE number, even if not strictly guaranteed by the standard. But even then it only has an actual size of 16 bits if
@@ -2181,7 +2181,7 @@ namespace half_float
/// Identity.
/// \param arg operand
/// \return unchanged operand
/// \return uncahnged operand
template<typename T> HALF_CONSTEXPR typename enable<T,T>::type operator+(T arg) { return arg; }
/// Negation.
@@ -2620,7 +2620,7 @@ namespace half_float
/// Multiply by power of two.
/// \param arg number to modify
/// \param exp power of two to multiply with
/// \return \a arg multiplied by 2 raised to \a exp
/// \return \a arg multplied by 2 raised to \a exp
// template<typename T> typename enable<half,T>::type ldexp(T arg, int exp) { return functions::scalbln(arg, exp); }
inline half ldexp(half arg, int exp) { return functions::scalbln(arg, exp); }
inline half ldexp(expr arg, int exp) { return functions::scalbln(arg, exp); }
@@ -2636,7 +2636,7 @@ namespace half_float
/// Multiply by power of two.
/// \param arg number to modify
/// \param exp power of two to multiply with
/// \return \a arg multiplied by 2 raised to \a exp
/// \return \a arg multplied by 2 raised to \a exp
// template<typename T> typename enable<half,T>::type scalbn(T arg, int exp) { return functions::scalbln(arg, exp); }
inline half scalbn(half arg, int exp) { return functions::scalbln(arg, exp); }
inline half scalbn(expr arg, int exp) { return functions::scalbln(arg, exp); }
@@ -2644,7 +2644,7 @@ namespace half_float
/// Multiply by power of two.
/// \param arg number to modify
/// \param exp power of two to multiply with
/// \return \a arg multiplied by 2 raised to \a exp
/// \return \a arg multplied by 2 raised to \a exp
// template<typename T> typename enable<half,T>::type scalbln(T arg, long exp) { return functions::scalbln(arg, exp); }
inline half scalbln(half arg, long exp) { return functions::scalbln(arg, exp); }
inline half scalbln(expr arg, long exp) { return functions::scalbln(arg, exp); }

View File

@@ -1,13 +1,35 @@
/*
* @brief hipError_t
* @enum
* @ingroup Enumerations
*/
// Developer note - when updating these, update the hipErrorName and hipErrorString functions in
// NVCC and HCC paths Also update the hipCUDAErrorTohipError function in NVCC path.
Copyright (c) 2015 - 2022 Advanced Micro Devices, Inc. All rights reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/
#ifndef HIP_H
#define HIP_H
// Ignoring error-code return values from hip APIs is discouraged. On C++17,
// we can make that yield a warning
#if __cplusplus >= 201703L
#define __HIP_NODISCARD [[nodiscard]]
#else
#define __HIP_NODISCARD
#endif
/*
* @brief hipError_t
@@ -17,9 +39,7 @@
// Developer note - when updating these, update the hipErrorName and hipErrorString functions in
// NVCC and HCC paths Also update the hipCUDAErrorTohipError function in NVCC path.
#include <cstddef>
typedef enum hipError_t {
typedef enum __HIP_NODISCARD hipError_t {
hipSuccess = 0, ///< Successful completion.
hipErrorInvalidValue = 1, ///< One or more of the parameters passed to the API call is NULL
///< or not in an acceptable range.
@@ -73,6 +93,7 @@ typedef enum hipError_t {
hipErrorInvalidHandle = 400,
// Deprecated
hipErrorInvalidResourceHandle = 400, ///< Resource handle (hipEvent_t or hipStream_t) invalid.
hipErrorIllegalState = 401, ///< Resource required is not in a valid state to perform operation.
hipErrorNotFound = 500,
hipErrorNotReady = 600, ///< Indicates that asynchronous operations enqueued earlier are not
///< ready. This is not actually an error, but is used to distinguish
@@ -86,6 +107,7 @@ typedef enum hipError_t {
hipErrorPeerAccessNotEnabled =
705, ///< Peer access was never enabled from the current device.
hipErrorSetOnActiveProcess = 708,
hipErrorContextIsDestroyed = 709,
hipErrorAssert = 710, ///< Produced when the kernel calls assert.
hipErrorHostMemoryAlreadyRegistered =
712, ///< Produced when trying to lock a page-locked memory.
@@ -98,6 +120,32 @@ typedef enum hipError_t {
///< that was launched via cooperative launch APIs exceeds the maximum number of
///< allowed blocks for the current device
hipErrorNotSupported = 801, ///< Produced when the hip API is not supported/implemented
hipErrorStreamCaptureUnsupported = 900, ///< The operation is not permitted when the stream
///< is capturing.
hipErrorStreamCaptureInvalidated = 901, ///< The current capture sequence on the stream
///< has been invalidated due to a previous error.
hipErrorStreamCaptureMerge = 902, ///< The operation would have resulted in a merge of
///< two independent capture sequences.
hipErrorStreamCaptureUnmatched = 903, ///< The capture was not initiated in this stream.
hipErrorStreamCaptureUnjoined = 904, ///< The capture sequence contains a fork that was not
///< joined to the primary stream.
hipErrorStreamCaptureIsolation = 905, ///< A dependency would have been created which crosses
///< the capture sequence boundary. Only implicit
///< in-stream ordering dependencies are allowed
///< to cross the boundary
hipErrorStreamCaptureImplicit = 906, ///< The operation would have resulted in a disallowed
///< implicit dependency on a current capture sequence
///< from hipStreamLegacy.
hipErrorCapturedEvent = 907, ///< The operation is not permitted on an event which was last
///< recorded in a capturing stream.
hipErrorStreamCaptureWrongThread = 908, ///< A stream capture sequence not initiated with
///< the hipStreamCaptureModeRelaxed argument to
///< hipStreamBeginCapture was passed to
///< hipStreamEndCapture in a different thread.
hipErrorGraphExecUpdateFailure = 910, ///< This error indicates that the graph update
///< not performed because it included changes which
///< violated constraintsspecific to instantiated graph
///< update.
hipErrorUnknown = 999, //< Unknown error.
// HSA Runtime Error Codes start here.
hipErrorRuntimeMemory = 1052, ///< HSA runtime memory call returned error. Typically not seen
@@ -107,35 +155,154 @@ typedef enum hipError_t {
hipErrorTbd ///< Marker that more error codes are needed.
} hipError_t;
#undef __HIP_NODISCARD
/*
* @brief hipDeviceAttribute_t
* @enum
* @ingroup Enumerations
*/
typedef enum hipDeviceAttribute_t {
hipDeviceAttributeCudaCompatibleBegin = 0,
hipDeviceAttributeEccEnabled = hipDeviceAttributeCudaCompatibleBegin, ///< Whether ECC support is enabled.
hipDeviceAttributeAccessPolicyMaxWindowSize, ///< Cuda only. The maximum size of the window policy in bytes.
hipDeviceAttributeAsyncEngineCount, ///< Cuda only. Asynchronous engines number.
hipDeviceAttributeCanMapHostMemory, ///< Whether host memory can be mapped into device address space
hipDeviceAttributeCanUseHostPointerForRegisteredMem,///< Cuda only. Device can access host registered memory
///< at the same virtual address as the CPU
hipDeviceAttributeClockRate, ///< Peak clock frequency in kilohertz.
hipDeviceAttributeComputeMode, ///< Compute mode that device is currently in.
hipDeviceAttributeComputePreemptionSupported, ///< Cuda only. Device supports Compute Preemption.
hipDeviceAttributeConcurrentKernels, ///< Device can possibly execute multiple kernels concurrently.
hipDeviceAttributeConcurrentManagedAccess, ///< Device can coherently access managed memory concurrently with the CPU
hipDeviceAttributeCooperativeLaunch, ///< Support cooperative launch
hipDeviceAttributeCooperativeMultiDeviceLaunch, ///< Support cooperative launch on multiple devices
hipDeviceAttributeDeviceOverlap, ///< Cuda only. Device can concurrently copy memory and execute a kernel.
///< Deprecated. Use instead asyncEngineCount.
hipDeviceAttributeDirectManagedMemAccessFromHost, ///< Host can directly access managed memory on
///< the device without migration
hipDeviceAttributeGlobalL1CacheSupported, ///< Cuda only. Device supports caching globals in L1
hipDeviceAttributeHostNativeAtomicSupported, ///< Cuda only. Link between the device and the host supports native atomic operations
hipDeviceAttributeIntegrated, ///< Device is integrated GPU
hipDeviceAttributeIsMultiGpuBoard, ///< Multiple GPU devices.
hipDeviceAttributeKernelExecTimeout, ///< Run time limit for kernels executed on the device
hipDeviceAttributeL2CacheSize, ///< Size of L2 cache in bytes. 0 if the device doesn't have L2 cache.
hipDeviceAttributeLocalL1CacheSupported, ///< caching locals in L1 is supported
hipDeviceAttributeLuid, ///< Cuda only. 8-byte locally unique identifier in 8 bytes. Undefined on TCC and non-Windows platforms
hipDeviceAttributeLuidDeviceNodeMask, ///< Cuda only. Luid device node mask. Undefined on TCC and non-Windows platforms
hipDeviceAttributeComputeCapabilityMajor, ///< Major compute capability version number.
hipDeviceAttributeManagedMemory, ///< Device supports allocating managed memory on this system
hipDeviceAttributeMaxBlocksPerMultiProcessor, ///< Cuda only. Max block size per multiprocessor
hipDeviceAttributeMaxBlockDimX, ///< Max block size in width.
hipDeviceAttributeMaxBlockDimY, ///< Max block size in height.
hipDeviceAttributeMaxBlockDimZ, ///< Max block size in depth.
hipDeviceAttributeMaxGridDimX, ///< Max grid size in width.
hipDeviceAttributeMaxGridDimY, ///< Max grid size in height.
hipDeviceAttributeMaxGridDimZ, ///< Max grid size in depth.
hipDeviceAttributeMaxSurface1D, ///< Maximum size of 1D surface.
hipDeviceAttributeMaxSurface1DLayered, ///< Cuda only. Maximum dimensions of 1D layered surface.
hipDeviceAttributeMaxSurface2D, ///< Maximum dimension (width, height) of 2D surface.
hipDeviceAttributeMaxSurface2DLayered, ///< Cuda only. Maximum dimensions of 2D layered surface.
hipDeviceAttributeMaxSurface3D, ///< Maximum dimension (width, height, depth) of 3D surface.
hipDeviceAttributeMaxSurfaceCubemap, ///< Cuda only. Maximum dimensions of Cubemap surface.
hipDeviceAttributeMaxSurfaceCubemapLayered, ///< Cuda only. Maximum dimension of Cubemap layered surface.
hipDeviceAttributeMaxTexture1DWidth, ///< Maximum size of 1D texture.
hipDeviceAttributeMaxTexture1DLayered, ///< Cuda only. Maximum dimensions of 1D layered texture.
hipDeviceAttributeMaxTexture1DLinear, ///< Maximum number of elements allocatable in a 1D linear texture.
///< Use cudaDeviceGetTexture1DLinearMaxWidth() instead on Cuda.
hipDeviceAttributeMaxTexture1DMipmap, ///< Cuda only. Maximum size of 1D mipmapped texture.
hipDeviceAttributeMaxTexture2DWidth, ///< Maximum dimension width of 2D texture.
hipDeviceAttributeMaxTexture2DHeight, ///< Maximum dimension hight of 2D texture.
hipDeviceAttributeMaxTexture2DGather, ///< Cuda only. Maximum dimensions of 2D texture if gather operations performed.
hipDeviceAttributeMaxTexture2DLayered, ///< Cuda only. Maximum dimensions of 2D layered texture.
hipDeviceAttributeMaxTexture2DLinear, ///< Cuda only. Maximum dimensions (width, height, pitch) of 2D textures bound to pitched memory.
hipDeviceAttributeMaxTexture2DMipmap, ///< Cuda only. Maximum dimensions of 2D mipmapped texture.
hipDeviceAttributeMaxTexture3DWidth, ///< Maximum dimension width of 3D texture.
hipDeviceAttributeMaxTexture3DHeight, ///< Maximum dimension height of 3D texture.
hipDeviceAttributeMaxTexture3DDepth, ///< Maximum dimension depth of 3D texture.
hipDeviceAttributeMaxTexture3DAlt, ///< Cuda only. Maximum dimensions of alternate 3D texture.
hipDeviceAttributeMaxTextureCubemap, ///< Cuda only. Maximum dimensions of Cubemap texture
hipDeviceAttributeMaxTextureCubemapLayered, ///< Cuda only. Maximum dimensions of Cubemap layered texture.
hipDeviceAttributeMaxThreadsDim, ///< Maximum dimension of a block
hipDeviceAttributeMaxThreadsPerBlock, ///< Maximum number of threads per block.
hipDeviceAttributeMaxThreadsPerMultiProcessor, ///< Maximum resident threads per multiprocessor.
hipDeviceAttributeMaxPitch, ///< Maximum pitch in bytes allowed by memory copies
hipDeviceAttributeMemoryBusWidth, ///< Global memory bus width in bits.
hipDeviceAttributeMemoryClockRate, ///< Peak memory clock frequency in kilohertz.
hipDeviceAttributeComputeCapabilityMinor, ///< Minor compute capability version number.
hipDeviceAttributeMultiGpuBoardGroupID, ///< Cuda only. Unique ID of device group on the same multi-GPU board
hipDeviceAttributeMultiprocessorCount, ///< Number of multiprocessors on the device.
hipDeviceAttributeName, ///< Device name.
hipDeviceAttributePageableMemoryAccess, ///< Device supports coherently accessing pageable memory
///< without calling hipHostRegister on it
hipDeviceAttributePageableMemoryAccessUsesHostPageTables, ///< Device accesses pageable memory via the host's page tables
hipDeviceAttributePciBusId, ///< PCI Bus ID.
hipDeviceAttributePciDeviceId, ///< PCI Device ID.
hipDeviceAttributePciDomainID, ///< PCI Domain ID.
hipDeviceAttributePersistingL2CacheMaxSize, ///< Cuda11 only. Maximum l2 persisting lines capacity in bytes
hipDeviceAttributeMaxRegistersPerBlock, ///< 32-bit registers available to a thread block. This number is shared
///< by all thread blocks simultaneously resident on a multiprocessor.
hipDeviceAttributeMaxRegistersPerMultiprocessor, ///< 32-bit registers available per block.
hipDeviceAttributeReservedSharedMemPerBlock, ///< Cuda11 only. Shared memory reserved by CUDA driver per block.
hipDeviceAttributeMaxSharedMemoryPerBlock, ///< Maximum shared memory available per block in bytes.
hipDeviceAttributeSharedMemPerBlockOptin, ///< Cuda only. Maximum shared memory per block usable by special opt in.
hipDeviceAttributeSharedMemPerMultiprocessor, ///< Cuda only. Shared memory available per multiprocessor.
hipDeviceAttributeSingleToDoublePrecisionPerfRatio, ///< Cuda only. Performance ratio of single precision to double precision.
hipDeviceAttributeStreamPrioritiesSupported, ///< Cuda only. Whether to support stream priorities.
hipDeviceAttributeSurfaceAlignment, ///< Cuda only. Alignment requirement for surfaces
hipDeviceAttributeTccDriver, ///< Cuda only. Whether device is a Tesla device using TCC driver
hipDeviceAttributeTextureAlignment, ///< Alignment requirement for textures
hipDeviceAttributeTexturePitchAlignment, ///< Pitch alignment requirement for 2D texture references bound to pitched memory;
hipDeviceAttributeTotalConstantMemory, ///< Constant memory size in bytes.
hipDeviceAttributeTotalGlobalMem, ///< Global memory available on devicice.
hipDeviceAttributeUnifiedAddressing, ///< Cuda only. An unified address space shared with the host.
hipDeviceAttributeUuid, ///< Cuda only. Unique ID in 16 byte.
hipDeviceAttributeWarpSize, ///< Warp size in threads.
hipDeviceAttributeMemoryPoolsSupported, ///< Device supports HIP Stream Ordered Memory Allocator
hipDeviceAttributeCudaCompatibleEnd = 9999,
hipDeviceAttributeAmdSpecificBegin = 10000,
hipDeviceAttributeClockInstructionRate = hipDeviceAttributeAmdSpecificBegin, ///< Frequency in khz of the timer used by the device-side "clock*"
hipDeviceAttributeArch, ///< Device architecture
hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, ///< Maximum shared memory per multiprocessor.
hipDeviceAttributeGcnArch, ///< Device gcn architecture
hipDeviceAttributeGcnArchName, ///< Device gcnArch name in 256 bytes
hipDeviceAttributeHdpMemFlushCntl, ///< Address of the HDP_MEM_COHERENCY_FLUSH_CNTL register
hipDeviceAttributeHdpRegFlushCntl, ///< Address of the HDP_REG_COHERENCY_FLUSH_CNTL register
hipDeviceAttributeCooperativeMultiDeviceUnmatchedFunc, ///< Supports cooperative launch on multiple
///< devices with unmatched functions
hipDeviceAttributeCooperativeMultiDeviceUnmatchedGridDim, ///< Supports cooperative launch on multiple
///< devices with unmatched grid dimensions
hipDeviceAttributeCooperativeMultiDeviceUnmatchedBlockDim, ///< Supports cooperative launch on multiple
///< devices with unmatched block dimensions
hipDeviceAttributeCooperativeMultiDeviceUnmatchedSharedMem, ///< Supports cooperative launch on multiple
///< devices with unmatched shared memories
hipDeviceAttributeIsLargeBar, ///< Whether it is LargeBar
hipDeviceAttributeAsicRevision, ///< Revision of the GPU in this device
hipDeviceAttributeCanUseStreamWaitValue, ///< '1' if Device supports hipStreamWaitValue32() and
///< hipStreamWaitValue64(), '0' otherwise.
hipDeviceAttributeImageSupport, ///< '1' if Device supports image, '0' otherwise.
hipDeviceAttributePhysicalMultiProcessorCount, ///< All available physical compute
///< units for the device
hipDeviceAttributeFineGrainSupport, ///< '1' if Device supports fine grain, '0' otherwise
hipDeviceAttributeAmdSpecificEnd = 19999,
hipDeviceAttributeVendorSpecificBegin = 20000,
// Extended attributes for vendors
} hipDeviceAttribute_t;
#define HIP_LAUNCH_PARAM_BUFFER_POINTER ((void*)0x01)
#define HIP_LAUNCH_PARAM_BUFFER_SIZE ((void*)0x02)
#define HIP_LAUNCH_PARAM_END ((void*)0x03)
// API-visible structures
typedef struct ihipCtx_t* hipCtx_t;
// Note many APIs also use integer deviceIds as an alternative to the device pointer:
typedef int hipDevice_t;
typedef enum hipDeviceP2PAttr {
hipDevP2PAttrPerformanceRank = 0,
hipDevP2PAttrAccessSupported,
hipDevP2PAttrNativeAtomicSupported,
hipDevP2PAttrHipArrayAccessSupported
} hipDeviceP2PAttr;
typedef struct ihipStream_t* hipStream_t;
#define hipIpcMemLazyEnablePeerAccess 0
#define HIP_IPC_HANDLE_SIZE 64
typedef struct hipIpcMemHandle_st {
char reserved[HIP_IPC_HANDLE_SIZE];
} hipIpcMemHandle_t;
typedef struct hipIpcEventHandle_st {
char reserved[HIP_IPC_HANDLE_SIZE];
} hipIpcEventHandle_t;
typedef struct ihipModule_t* hipModule_t;
typedef struct ihipModuleSymbol_t* hipFunction_t;
typedef struct hipFuncAttributes {
@@ -150,91 +317,8 @@ typedef struct hipFuncAttributes {
int ptxVersion;
size_t sharedSizeBytes;
} hipFuncAttributes;
typedef struct ihipEvent_t* hipEvent_t;
/*
* @brief hipDeviceAttribute_t
* @enum
* @ingroup Enumerations
*/
typedef enum hipDeviceAttribute_t {
hipDeviceAttributeMaxThreadsPerBlock, ///< Maximum number of threads per block.
hipDeviceAttributeMaxBlockDimX, ///< Maximum x-dimension of a block.
hipDeviceAttributeMaxBlockDimY, ///< Maximum y-dimension of a block.
hipDeviceAttributeMaxBlockDimZ, ///< Maximum z-dimension of a block.
hipDeviceAttributeMaxGridDimX, ///< Maximum x-dimension of a grid.
hipDeviceAttributeMaxGridDimY, ///< Maximum y-dimension of a grid.
hipDeviceAttributeMaxGridDimZ, ///< Maximum z-dimension of a grid.
hipDeviceAttributeMaxSharedMemoryPerBlock, ///< Maximum shared memory available per block in
///< bytes.
hipDeviceAttributeTotalConstantMemory, ///< Constant memory size in bytes.
hipDeviceAttributeWarpSize, ///< Warp size in threads.
hipDeviceAttributeMaxRegistersPerBlock, ///< Maximum number of 32-bit registers available to a
///< thread block. This number is shared by all thread
///< blocks simultaneously resident on a
///< multiprocessor.
hipDeviceAttributeClockRate, ///< Peak clock frequency in kilohertz.
hipDeviceAttributeMemoryClockRate, ///< Peak memory clock frequency in kilohertz.
hipDeviceAttributeMemoryBusWidth, ///< Global memory bus width in bits.
hipDeviceAttributeMultiprocessorCount, ///< Number of multiprocessors on the device.
hipDeviceAttributeComputeMode, ///< Compute mode that device is currently in.
hipDeviceAttributeL2CacheSize, ///< Size of L2 cache in bytes. 0 if the device doesn't have L2
///< cache.
hipDeviceAttributeMaxThreadsPerMultiProcessor, ///< Maximum resident threads per
///< multiprocessor.
hipDeviceAttributeComputeCapabilityMajor, ///< Major compute capability version number.
hipDeviceAttributeComputeCapabilityMinor, ///< Minor compute capability version number.
hipDeviceAttributeConcurrentKernels, ///< Device can possibly execute multiple kernels
///< concurrently.
hipDeviceAttributePciBusId, ///< PCI Bus ID.
hipDeviceAttributePciDeviceId, ///< PCI Device ID.
hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, ///< Maximum Shared Memory Per
///< Multiprocessor.
hipDeviceAttributeIsMultiGpuBoard, ///< Multiple GPU devices.
hipDeviceAttributeIntegrated, ///< iGPU
hipDeviceAttributeCooperativeLaunch, ///< Support cooperative launch
hipDeviceAttributeCooperativeMultiDeviceLaunch, ///< Support cooperative launch on multiple devices
hipDeviceAttributeMaxTexture1DWidth, ///< Maximum number of elements in 1D images
hipDeviceAttributeMaxTexture2DWidth, ///< Maximum dimension width of 2D images in image elements
hipDeviceAttributeMaxTexture2DHeight, ///< Maximum dimension height of 2D images in image elements
hipDeviceAttributeMaxTexture3DWidth, ///< Maximum dimension width of 3D images in image elements
hipDeviceAttributeMaxTexture3DHeight, ///< Maximum dimension height of 3D images in image elements
hipDeviceAttributeMaxTexture3DDepth, ///< Maximum dimension depth of 3D images in image elements
hipDeviceAttributeHdpMemFlushCntl, ///< Address of the HDP_MEM_COHERENCY_FLUSH_CNTL register
hipDeviceAttributeHdpRegFlushCntl, ///< Address of the HDP_REG_COHERENCY_FLUSH_CNTL register
hipDeviceAttributeMaxPitch, ///< Maximum pitch in bytes allowed by memory copies
hipDeviceAttributeTextureAlignment, ///<Alignment requirement for textures
hipDeviceAttributeTexturePitchAlignment, ///<Pitch alignment requirement for 2D texture references bound to pitched memory;
hipDeviceAttributeKernelExecTimeout, ///<Run time limit for kernels executed on the device
hipDeviceAttributeCanMapHostMemory, ///<Device can map host memory into device address space
hipDeviceAttributeEccEnabled, ///<Device has ECC support enabled
hipDeviceAttributeCooperativeMultiDeviceUnmatchedFunc, ///< Supports cooperative launch on multiple
///devices with unmatched functions
hipDeviceAttributeCooperativeMultiDeviceUnmatchedGridDim, ///< Supports cooperative launch on multiple
///devices with unmatched grid dimensions
hipDeviceAttributeCooperativeMultiDeviceUnmatchedBlockDim, ///< Supports cooperative launch on multiple
///devices with unmatched block dimensions
hipDeviceAttributeCooperativeMultiDeviceUnmatchedSharedMem, ///< Supports cooperative launch on multiple
///devices with unmatched shared memories
hipDeviceAttributeAsicRevision, ///< Revision of the GPU in this device
hipDeviceAttributeManagedMemory, ///< Device supports allocating managed memory on this system
hipDeviceAttributeDirectManagedMemAccessFromHost, ///< Host can directly access managed memory on
/// the device without migration
hipDeviceAttributeConcurrentManagedAccess, ///< Device can coherently access managed memory
/// concurrently with the CPU
hipDeviceAttributePageableMemoryAccess, ///< Device supports coherently accessing pageable memory
/// without calling hipHostRegister on it
hipDeviceAttributePageableMemoryAccessUsesHostPageTables, ///< Device accesses pageable memory via
/// the host's page tables
hipDeviceAttributeCanUseStreamWaitValue ///< '1' if Device supports hipStreamWaitValue32() and
///< hipStreamWaitValue64() , '0' otherwise.
} hipDeviceAttribute_t;
typedef void* hipDeviceptr_t;
/*
@@ -262,7 +346,6 @@ typedef enum hipJitOption {
hipJitOptionFastCompile,
hipJitOptionNumOptions
} hipJitOption;
/**
* @warning On AMD devices and some Nvidia devices, these hints and controls are ignored.
*/
@@ -271,7 +354,6 @@ typedef enum hipFuncAttribute {
hipFuncAttributePreferredSharedMemoryCarveout = 9,
hipFuncAttributeMax
} hipFuncAttribute;
/**
* @warning On AMD devices and some Nvidia devices, these hints and controls are ignored.
*/
@@ -282,7 +364,4 @@ typedef enum hipFuncCache_t {
hipFuncCachePreferEqual, ///< prefer equal size L1 cache and shared memory
} hipFuncCache_t;
#define HIP_LAUNCH_PARAM_BUFFER_POINTER ((void*)0x01)
#define HIP_LAUNCH_PARAM_BUFFER_SIZE ((void*)0x02)
#define HIP_LAUNCH_PARAM_END ((void*)0x03)
#endif

View File

@@ -45,6 +45,7 @@ public:
value *get_int64(uint64_t val);
value *get_float16(float val);
value *get_float32(float val);
value *get_float64(float val);
value *get_range(int32_t lo, int32_t hi);
// Types
type *get_void_ty();

View File

@@ -196,7 +196,7 @@ mma_layout::mma_layout(size_t num_warps,
tensor_core_type_ = get_mma_type(dot);
/* fragments per warp */
// try to make things as square as possible to maximize data re-use
if(tgt->as_nvidia()->sm() < 80){
if(tgt->as_nvidia() && tgt->as_nvidia()->sm() < 80){
fpw_ = {2, 2, 1};
auto ord_a = layout_a->get_order();
auto ord_b = layout_b->get_order();

View File

@@ -79,11 +79,13 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(
ir::module& ir, llvm::LLVMContext& ctx, codegen::target* target,
int num_warps, int num_stages, int& shared_static,
const ExternLibMap& extern_lib_map) {
std::cout << "pass.cc: add_passes_to_emit_bin" << std::endl;
// generate llvm code
std::string name = ir.get_function_list()[0]->get_name();
std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
// optimizations
bool has_sm80 = target->as_nvidia() && target->as_nvidia()->sm() >= 80;
// bool has_sm80 = target->as_nvidia() && target->as_nvidia()->sm() >= 80;
bool has_sm80 = false;
// create passes
codegen::analysis::align align;
codegen::transform::inliner inliner;
@@ -149,13 +151,6 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(
// ir.print(std::cout);
isel.visit(ir, *llvm);
shared_static = allocation.allocated_size();
if (target->as_nvidia() && target->as_nvidia()->sm() < 70) {
// sm < 70 (Pascal) has little shared memory resource.
// Instead of having "Error: Invalid argument" on launching a kernel, let's throw an error here.
if (shared_static >= 65536) {
throw std::runtime_error("Device does not support shared memory of " + std::to_string(shared_static) + "bytes");
}
}
if (isel.get_extern_lib_map().size() > 0) {
// If there's any extern lib calls,

View File

@@ -16,7 +16,13 @@
#include "triton/ir/utils.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IR/Type.h"
#ifdef USE_ROCM
#include "llvm/IR/IntrinsicsAMDGPU.h"
#else
#include "llvm/IR/IntrinsicsNVPTX.h"
#endif
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/InlineAsm.h"
@@ -91,6 +97,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
#define f16_ty builder_->getHalfTy()
#define bf16_ty builder_->getInt16Ty()
#define f32_ty builder_->getFloatTy()
#define f64_ty builder_->getDoubleTy()
#define i1_ty builder_->getInt1Ty()
#define i8_ty builder_->getInt8Ty()
#define i16_ty builder_->getInt16Ty()
@@ -410,48 +417,44 @@ void generator::visit_binary_operator(ir::binary_operator*x) {
for(indices_t idx: idxs_.at(x)){
Value *lhs = vals_[x->get_operand(0)][idx];
Value *rhs = vals_[x->get_operand(1)][idx];
// manually select bf16 bin op
if (x->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty()) {
assert(x->get_operand(1)->get_type()->get_scalar_ty()->is_bf16_ty());
if (x->get_op() == tt::FAdd) { // a + b = a * 1.0 + b
if (x->get_op() == tt::FAdd) {
InlineAsm *bf16_add_asm =
InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
"{ .reg .b16 c; \n\t"
" mov.b16 c, 0x3f80U; \n\t" // 1.0
" fma.rn.bf16 $0, $1, c, $2; } \n\t",
"=h,h,h", false);
InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
"v_lshlrev_b32 $1, 16, $1 "
"v_lshlrev_b32 $2, 16, $2 "
"v_add_f32 $0, $1, $2 "
"v_lshrrev_b32 $0, 16, $0 ",
"=v,v,v", false);
vals_[x][idx] = builder_->CreateCall(bf16_add_asm, {lhs, rhs});
} else if (x->get_op() == tt::FSub) { // a - b = b * (-1.0) + a
InlineAsm *bf16_sub_asm =
InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
" { .reg .b16 c; \n\t"
" mov.b16 c, 0xbf80U; \n\t" // -1.0
" fma.rn.bf16 $0, $2, c, $1;} \n\t",
"=h,h,h", false);
InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
"v_lshlrev_b32 $1, 16, $1 "
"v_lshlrev_b32 $2, 16, $2 "
"v_sub_f32 $0, $1, $2 "
"v_lshrrev_b32 $0, 16, $0 ",
"=v,v,v", false);
vals_[x][idx] = builder_->CreateCall(bf16_sub_asm, {lhs, rhs});
} else if (x->get_op() == tt::FMul) { // a * b = a*b + 0
InlineAsm *bf16_mul_asm =
InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
" { .reg .b16 c; \n\t"
" mov.b16 c, 0x8000U; \n\t" // 0.0
" fma.rn.bf16 $0, $1, $2, c;} \n\t",
"=h,h,h", false);
"v_lshlrev_b32 $1, 16, $1 "
"v_lshlrev_b32 $2, 16, $2 "
"v_mul_f32 $0, $1, $2 "
"v_lshrrev_b32 $0, 16, $0 ",
"=v,v,v", false);
vals_[x][idx] = builder_->CreateCall(bf16_mul_asm, {lhs, rhs});
} else
throw std::runtime_error("invalid bin op for bf16");
} else { // not bf16
}
else {
auto op = cvt(x->get_op());
if(op == ll::Add)
vals_[x][idx] = add(lhs, rhs);
else if(op == ll::Mul)
vals_[x][idx] = mul(lhs, rhs);
else if(op == ll::FDiv && !x->get_fdiv_ieee_rounding() &&
x->get_type()->get_scalar_ty()->is_fp32_ty()){
InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {f32_ty, f32_ty}, false),
" div.full.f32 $0, $1, $2;", "=r,r,r", false);
vals_[x][idx] = builder_->CreateCall(ptx, {lhs, rhs});
}
else
vals_[x][idx] = bin_op(op, lhs, rhs);
}
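The AMDGCN inline assembly above replaces the PTX fma.rn.bf16 trick: each bfloat16 operand is widened to fp32 by shifting its 16 bits into the high half of a 32-bit register, the op runs in fp32, and the high 16 bits of the result are kept. A minimal host-side sketch of that bit pattern, using numpy and assuming plain truncation (rather than round-to-nearest) on the way back down:

import numpy as np

def bf16_add(a_bits: int, b_bits: int) -> int:
    # v_lshlrev_b32: move the bf16 payload into the upper half of an fp32
    a = (np.array([a_bits], dtype=np.uint32) << 16).view(np.float32)
    b = (np.array([b_bits], dtype=np.uint32) << 16).view(np.float32)
    # v_add_f32, then v_lshrrev_b32 to truncate back to bf16 bits
    return int(((a + b).view(np.uint32) >> 16)[0])

print(hex(bf16_add(0x3F80, 0x4000)))  # 1.0 + 2.0 -> 0x4040, i.e. 3.0 in bfloat16

Sub and mul follow the same widen/operate/narrow pattern, which is why the three GCN sequences above differ only in the middle v_* instruction.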
@@ -754,7 +757,7 @@ Value* generator::bf16_to_fp32(Value *in0){
}
Value* generator::fp32_to_bf16(Value *in0){
if(tgt_->as_nvidia()->sm() >= 80){
if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80){
InlineAsm *ptx = InlineAsm::get(FunctionType::get(bf16_ty, {f32_ty}, false),
"cvt.rn.bf16.f32 $0, $1;", "=h,r", false);
return call(ptx, {in0});
@@ -1120,6 +1123,22 @@ void generator::visit_load_inst(ir::load_inst* x){
ir::value *op = x->get_pointer_operand();
ir::masked_load_inst *mx = dynamic_cast<ir::masked_load_inst*>(x);
Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty());
#ifdef USE_ROCM
// code generation
auto idxs = idxs_.at(x);
for(size_t i = 0; i <idxs.size(); i += 1){
indices_t idx = idxs[i];
// pointer value
Value *ptr = vals_[op][idx];
// create load
Value *_ret = builder_->CreateLoad(ty, ptr);
// upload to global vals map
vals_[x][idx] = _ret;
}
#else
// compute vector width
size_t vec = 1;
bool is_mma_first_row = false;
@@ -1298,6 +1317,7 @@ void generator::visit_load_inst(ir::load_inst* x){
for(size_t ii = 0; ii < vec; ii++)
vals_[x][idxs[i+ii]] = extract_elt(rets[ii/tmp], ii % tmp);
}
#endif
}
void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
@@ -1316,6 +1336,23 @@ void generator::visit_store_inst(ir::store_inst * x){
// operands
ir::value *ptr_op = x->get_pointer_operand();
ir::value *val_op = x->get_value_operand();
#ifdef USE_ROCM
auto idxs = idxs_.at(val_op);
Type *ty = cvt(val_op->get_type()->get_scalar_ty());
for (size_t i = 0; i < idxs.size(); i += 1)
{
auto idx = idxs[i];
// pointer
Value *ptr = vals_[ptr_op][idx];
// value
Value *val = vals_.at(val_op)[idxs[i]];
// store value at pointer
store(val, ptr);
}
#else
ir::value *msk_op = nullptr;
if(auto* msk_st = dynamic_cast<ir::masked_store_inst*>(x))
msk_op = msk_st->get_mask_operand();
@@ -1431,6 +1468,7 @@ void generator::visit_store_inst(ir::store_inst * x){
args.push_back(policies_.at(x->get_eviction_policy()));
call(_asm, args);
}
#endif
}
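Under USE_ROCM both visit_load_inst and visit_store_inst above fall back to one plain load or store per element index, and the mask operand of the masked variants is not consulted; the vectorized, predicated PTX path is compiled out. On host data the fallback amounts to this (a toy numpy sketch, not the generated IR):

import numpy as np

src = np.arange(16, dtype=np.float32)
dst = np.empty_like(src)
for i in range(src.size):   # one CreateLoad / store() pair per index
    dst[i] = src[i]
assert np.array_equal(dst, src)

The trade-off is correctness and simplicity over the coalesced ld.global.v4/st.global.v4 sequences the NVIDIA path emits, matching the ROCm branches in the vectorization tests further down, which skip the PTX assertions.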
void generator::visit_unmasked_store_inst(ir::unmasked_store_inst* x) {
visit_store_inst(x);
@@ -1549,7 +1587,12 @@ void generator::visit_exp_inst(ir::exp_inst* x){
Constant *log2e = ConstantFP::get(f32_ty, 1.4426950408889634);
std::vector<llvm::Type*> tys = {f32_ty};
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
#ifdef USE_ROCM
llvm::Function *ex2 = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::exp2, tys);
#else
InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $0;", "=f,0", false);
#endif
for(auto idx: idxs_.at(x)){
Value *ex2arg = fmul(vals_[x->get_operand(0)][idx], log2e);
// Value *ex2arg = vals_[x->get_operand(0)][idx];
@@ -1563,7 +1606,11 @@ void generator::visit_exp_inst(ir::exp_inst* x){
void generator::visit_cos_inst(ir::cos_inst* x){
std::vector<llvm::Type*> tys = {f32_ty};
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
#ifdef USE_ROCM
llvm::Function *cos = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::cos, tys);
#else
InlineAsm *cos = InlineAsm::get(fn_ty, "cos.approx.f32 $0, $0;", "=f,0", false);
#endif
for(auto idx: idxs_.at(x)){
vals_[x][idx] = call(cos, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
}
@@ -1589,7 +1636,11 @@ void generator::visit_umulhi_inst(ir::umulhi_inst* x){
void generator::visit_sin_inst(ir::sin_inst* x){
std::vector<llvm::Type*> tys = {f32_ty};
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
#ifdef USE_ROCM
llvm::Function *sin = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::sin, tys);
#else
InlineAsm *sin = InlineAsm::get(fn_ty, "sin.approx.f32 $0, $0;", "=f,0", false);
#endif
for(auto idx: idxs_.at(x)){
vals_[x][idx] = call(sin, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
}
@@ -1602,7 +1653,11 @@ void generator::visit_log_inst(ir::log_inst* x){
Constant *rcplog2e = ConstantFP::get(f32_ty, 0.6931471805599453);
std::vector<llvm::Type*> tys = {f32_ty};
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
#ifdef USE_ROCM
llvm::Function *lg2 = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::log2, tys);
#else
InlineAsm *lg2 = InlineAsm::get(fn_ty, "lg2.approx.f32 $0, $1;", "=f,f", false);
#endif
for(auto idx: idxs_.at(x)){
Value *lg2arg = call(lg2, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
vals_[x][idx] = fmul(lg2arg, rcplog2e);
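The ROCm branches above lower exp and log through the llvm.exp2 and llvm.log2 intrinsics, so the surrounding code keeps the same change of base the PTX path uses: exp(x) = 2^(x * log2(e)) and log(x) = log2(x) * ln(2). A quick numeric check of those constants:

import math

LOG2E = 1.4426950408889634      # log2(e), multiplied in before ex2/exp2
RCPLOG2E = 0.6931471805599453   # ln(2), multiplied in after lg2/log2

x = 0.75
assert math.isclose(math.exp(x), 2.0 ** (x * LOG2E))
assert math.isclose(math.log(x), math.log2(x) * RCPLOG2E)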
@@ -1612,6 +1667,35 @@ void generator::visit_log_inst(ir::log_inst* x){
/**
* \brief Code Generation for `atomic_cas`
*/
#if defined(USE_ROCM)
void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
BasicBlock *current = builder_->GetInsertBlock();
Module *module = current->getModule();
Value *tid = tgt_->get_local_id(module, *builder_, 0);
Value *pred = icmp_eq(tid, i32(0));
BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
add_barrier();
tgt_->add_memfence(module, *builder_);
cond_br(pred, tid_0_bb, tid_0_done_bb);
builder_->SetInsertPoint(tid_0_bb);
Value *cas_ptr = vals_[cas->get_operand(0)][{}];
Value *cas_cmp = vals_[cas->get_operand(1)][{}];
Value *cas_val = vals_[cas->get_operand(2)][{}];
Value *old = atomic_cmp_xchg(cas_ptr, cas_cmp, cas_val, MaybeAlign(), AtomicOrdering::Monotonic, AtomicOrdering::Monotonic);
old = extract_val(old, std::vector<unsigned>{0});
Value *atom_ptr;
atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(cas)))), "");
atom_ptr = bit_cast(atom_ptr, ptr_ty(old->getType(), 3));
store(old, atom_ptr);
br(tid_0_done_bb);
builder_->SetInsertPoint(tid_0_done_bb);
tgt_->add_memfence(module, *builder_);
add_barrier();
vals_[cas][{}] = load(atom_ptr);
add_barrier();
}
#else
void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
BasicBlock *current = builder_->GetInsertBlock();
Module *module = current->getModule();
@@ -1646,12 +1730,66 @@ void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
vals_[cas][{}] = load(atom_ptr);
add_barrier();
}
#endif // defined(USE_ROCM)
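The USE_ROCM visit_atomic_cas_inst above serializes the operation: after a fence and barrier, only thread 0 issues the compare-and-swap, parks the old value in a shared-memory scratch slot, and every thread reads the result back after a second barrier. A host-side sketch of that shape using Python threads (a lock stands in for the hardware atomic, a dict for the LDS slot):

import threading

NUM_LANES = 8
cell = {"value": 0}                 # global memory cell being CAS'd
shared_slot = {}                    # LDS scratch written by lane 0
lock = threading.Lock()
barrier = threading.Barrier(NUM_LANES)
results = [None] * NUM_LANES

def lane(tid, cmp, val):
    barrier.wait()                  # memfence + add_barrier before the CAS
    if tid == 0:
        with lock:                  # atomic_cmp_xchg(cas_ptr, cas_cmp, cas_val)
            old = cell["value"]
            if old == cmp:
                cell["value"] = val
            shared_slot["old"] = old
    barrier.wait()                  # every lane waits for lane 0's result
    results[tid] = shared_slot["old"]

threads = [threading.Thread(target=lane, args=(t, 0, 42)) for t in range(NUM_LANES)]
for t in threads: t.start()
for t in threads: t.join()
print(cell["value"], results)       # 42 [0, 0, 0, 0, 0, 0, 0, 0]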
/**
* \brief Code Generation for `atomic_rmw`
*/
#if defined(USE_ROCM)
void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
ir::value* ptr = atom->get_operand(0);
if (atom->get_op() == ir::atomic_rmw_op_t::Add ||
atom->get_op() == ir::atomic_rmw_op_t::Max ||
atom->get_op() == ir::atomic_rmw_op_t::Min ||
atom->get_op() == ir::atomic_rmw_op_t::UMax ||
atom->get_op() == ir::atomic_rmw_op_t::UMin ||
atom->get_op() == ir::atomic_rmw_op_t::Xchg) {
BasicBlock *current = builder_->GetInsertBlock();
Module *module = current->getModule();
Value *rmw_ptr = vals_[atom->get_operand(0)][{}];
Value *rmw_val = vals_[atom->get_operand(1)][{}];
Value *tid = tgt_->get_local_id(module, *builder_, 0);
Value *pred = icmp_eq(tid, i32(0));
BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
tgt_->add_memfence(module, *builder_);
add_barrier();
cond_br(pred, tid_0_bb, tid_0_done_bb);
builder_->SetInsertPoint(tid_0_bb);
AtomicRMWInst::BinOp binop;
switch (atom->get_op()) {
case ir::atomic_rmw_op_t::Add:
binop = AtomicRMWInst::Add;
break;
case ir::atomic_rmw_op_t::Max:
binop = AtomicRMWInst::Max;
break;
case ir::atomic_rmw_op_t::Min:
binop = AtomicRMWInst::Min;
break;
case ir::atomic_rmw_op_t::UMax:
binop = AtomicRMWInst::UMax;
break;
case ir::atomic_rmw_op_t::UMin:
binop = AtomicRMWInst::UMin;
break;
case ir::atomic_rmw_op_t::Xchg:
binop = AtomicRMWInst::Xchg;
break;
default:
throw std::runtime_error("Not supported!");
}
atomic_rmw(binop, rmw_ptr, rmw_val, MaybeAlign(), AtomicOrdering::Monotonic,
SyncScope::System);
br(tid_0_done_bb);
builder_->SetInsertPoint(tid_0_done_bb);
tgt_->add_memfence(module, *builder_);
return;
}
throw std::runtime_error("Not supported!");
}
#else // defined(USE_ROCM)
void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
ir::value *ptr = atom->get_operand(0);
ir::value* val = atom->get_operand(1);
ir::value* msk = atom->get_operand(2);
@@ -1756,6 +1894,7 @@ void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
}
}
}
#endif // defined(USE_ROCM)
/**
* \brief Code Generation for `mma.884` (V100)
@@ -2834,15 +2973,20 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
size_t red_axis = 1;
unsigned NK = A_shapes[red_axis];
bool is_outer = NK == 1;
#ifdef USE_ROCM
return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
#else
bool is_mma = layouts_->get(dot)->to_mma();
if(!is_outer && is_mma && tgt_->as_nvidia()->sm() < 80)
if(!is_outer && is_mma && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80)
return visit_mma884(dot, A, B, D, NK);
if(!is_outer && is_mma && tgt_->as_nvidia()->sm() >= 80)
if(!is_outer && is_mma && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
return visit_mma16816(dot, A, B, D, NK); // rename it as visit_mma_v2()?
if (dot->get_type()->get_scalar_ty()->is_fp32_ty() &&
A->get_type()->get_scalar_ty()->is_fp32_ty())
return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
throw std::runtime_error("dot has invalid operand type");
#endif
}
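With USE_ROCM the mma884/mma16816 tensor-core paths are bypassed entirely and every dot goes through visit_fmadot, i.e. a plain multiply-accumulate over the reduction axis NK. A host-side model of that fallback (numpy, fp32 accumulation):

import numpy as np

def fmadot(A, B):
    M, K = A.shape
    K2, N = B.shape
    assert K == K2                       # NK, the shared inner dimension
    C = np.zeros((M, N), dtype=np.float32)
    for k in range(K):                   # one rank-1 update (fused mul-add) per k
        C += A[:, k:k+1] * B[k:k+1, :]
    return C

A = np.random.rand(8, 4).astype(np.float32)
B = np.random.rand(4, 8).astype(np.float32)
assert np.allclose(fmadot(A, B), A @ B, atol=1e-5)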
void generator::visit_trans_inst(ir::trans_inst* trans) {
@@ -2875,8 +3019,14 @@ Value* generator::shared_off(const std::vector<unsigned>& shapes, const std::vec
inline Value* generator::shfl_sync(Value* acc, int32_t i){
Type* ty = acc->getType();
#ifdef USE_ROCM
std::string asm_str = "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;";
InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false);
#else
std::string asm_str = "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;";
InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false);
#endif
if(ty->getPrimitiveSizeInBits() <= 32)
return call(shfl, {acc, i32(i)});
acc = bit_cast(acc, vec_ty(f32_ty, 2));
@@ -3171,12 +3321,16 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
default: throw std::runtime_error("unreachable");
}
ir::value *arg = x->get_operand(0);
#ifdef USE_ROCM
visit_reducend_inst(x, do_acc, neutral);
#else
bool is_coalesced_scanline = layouts_->is_coalesced_scanline(x);
bool is_a100_mma = layouts_->is_a100_mma(x);
if (is_coalesced_scanline || is_a100_mma)
visit_reducend_inst_fast(x, do_acc, neutral);
else
visit_reducend_inst(x, do_acc, neutral);
#endif
}
/**
@@ -3645,6 +3799,7 @@ Value *generator::cast_shared_layout_ptr(analysis::data_layout *layout,
}
void generator::visit_function(ir::function* fn) {
std::cout << "generator.cc: generator::visit_function:" << std::endl;
idxs_.clear();
vals_.clear();
seen_.clear();
@@ -3654,6 +3809,7 @@ void generator::visit_function(ir::function* fn) {
// set attributes
std::cout << "\t// set attributes" << std::endl;
for(auto attr_pair: fn->attrs()){
unsigned id = attr_pair.first;
for(ir::attribute attr: attr_pair.second)
@@ -3664,19 +3820,24 @@ void generator::visit_function(ir::function* fn) {
}
}
// set metadata
std::cout << "\t// set metadata" << std::endl;
if(tgt_->is_gpu()){
tgt_->set_kernel(*builder_, ctx, mod_, ret);
#ifndef USE_ROCM
Metadata *md_args[] = {
ValueAsMetadata::get(ret),
MDString::get(ctx, "maxntidx"),
ValueAsMetadata::get(i32(num_warps_*32))
};
mod_->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(ctx, md_args));
#endif
}
// set arguments
std::cout << "\t// set arguments" << std::endl;
for(unsigned i = 0; i < fn->args().size(); i++)
vals_[fn->args()[i]][{}] = &*(ret->arg_begin() + i);
// create blocks
std::cout << "\t// create blocks" << std::endl;
auto blocks = ir::cfg::reverse_post_order(fn);
for(ir::basic_block *block: blocks) {
BasicBlock *dst_block = BasicBlock::Create(ctx, block->get_name(), ret);
@@ -3684,6 +3845,8 @@ void generator::visit_function(ir::function* fn) {
}
builder_->SetInsertPoint(bbs_[fn->blocks()[0]]);
// create policies
#ifndef USE_ROCM
std::cout << "\t// create policies" << std::endl;
if(tgt_->as_nvidia()->sm() >= 80)
for(ir::load_inst::EVICTION_POLICY evict: {ir::load_inst::EVICT_FIRST, ir::load_inst::EVICT_LAST}){
std::string policy = (evict == ir::load_inst::EVICT_FIRST) ? "evict_first" : "evict_last";
@@ -3691,15 +3854,23 @@ void generator::visit_function(ir::function* fn) {
InlineAsm* iasm = InlineAsm::get(FunctionType::get(i64_ty, {}), asm_str, "=l", false);
policies_[evict] = call(iasm);
}
#endif
// initialize layouts
std::cout << "\t// initialize layouts" << std::endl;
for(auto x: layouts_->get_all()){
visit_layout(x.second);
}
// generate LLVM-IR code
std::cout << "\t// generate LLVM-IR code" << std::endl;
for(ir::basic_block *block: blocks)
visit_basic_block(block);
// finalize
std::cout << "\t// finalize" << std::endl;
finalize_function(fn);
// verifyFunction
std::cout << "\t// verifyFunction" << std::endl;
llvm::verifyFunction(*ret);
}
@@ -3723,7 +3894,11 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) {
Value *_8 = i32(8);
Value *_16 = i32(16);
Value *_32 = i32(32);
#ifdef USE_ROCM
int cc = 1; // generate ir for older CUDA cards
#else
int cc = tgt_->as_nvidia()->sm();
#endif
std::vector<Value*> idx_m;
std::vector<Value*> idx_n;
std::vector<Value*> idx_z;
@@ -4114,6 +4289,7 @@ void generator::packed_type(ir::value* i){
}
void generator::visit(ir::module &src, llvm::Module &dst) {
std::cout << "generator.cc: generator::visit" << std::endl;
mod_ = &dst;
ctx_ = &dst.getContext();
builder_ = new Builder(*ctx_);

View File

@@ -15,10 +15,22 @@ namespace codegen{
// base
nvidia_cu_target* target::as_nvidia() {
return dynamic_cast<nvidia_cu_target*>(this);
#ifdef USE_ROCM
amd_cl_target *target::as_amd()
{
return dynamic_cast<amd_cl_target *>(this);
}
amd_cl_target *target::as_nvidia()
{
return this->as_amd();
}
#else
// causes segfault on ROCM
nvidia_cu_target *target::as_nvidia()
{
return dynamic_cast<nvidia_cu_target *>(this);
}
#endif
bool target::is_gpu() const {
return is_gpu_;
@@ -41,7 +53,8 @@ Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, un
}
Instruction* amd_cl_target::add_memfence(Module *module, IRBuilder<>& builder) {
throw std::runtime_error("not implemented");
Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::amdgcn_s_waitcnt);
return builder.CreateIntrinsic(Intrinsic::amdgcn_s_waitcnt, {}, {builder.getInt32(0)});
}
@@ -56,7 +69,50 @@ Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigne
}
Value* amd_cl_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
throw std::runtime_error("not implemented on AMD");
Function &F = *builder.GetInsertBlock()->getParent();
Module *Mod = F.getParent();
// We are indexing into this struct, and want to extract the grid_size_*
// fields.
//
// typedef struct hsa_kernel_dispatch_packet_s {
// uint16_t header;
// uint16_t setup;
// uint16_t workgroup_size_x ;
// uint16_t workgroup_size_y;
// uint16_t workgroup_size_z;
// uint16_t reserved0;
// uint32_t grid_size_x ;
// uint32_t grid_size_y ;
// uint32_t grid_size_z;
// .....
// } hsa_kernel_dispatch_packet_t
//
Function *DispatchPtrFn =
Intrinsic::getDeclaration(Mod, Intrinsic::amdgcn_dispatch_ptr);
CallInst *DispatchPtr = builder.CreateCall(DispatchPtrFn, {});
DispatchPtr->addAttribute(AttributeList::ReturnIndex, Attribute::NoAlias);
DispatchPtr->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
F.removeFnAttr("amdgpu-no-dispatch-ptr");
// Size of the dispatch packet struct.
DispatchPtr->addDereferenceableAttr(AttributeList::ReturnIndex, 64);
Type *I32Ty = Type::getInt32Ty(Mod->getContext());
// TODO: include AMDGPUAS:: declarations.
Value *CastDispatchPtr = builder.CreateBitCast(
DispatchPtr, PointerType::get(I32Ty, 4 /*AMDGPUAS::CONSTANT_ADDRESS*/));
// grid_size_x offset is 3*32bit
assert(ax < 3);
Value *GEP =
builder.CreateConstInBoundsGEP1_64(I32Ty, CastDispatchPtr, ax + 3);
LoadInst *Load = builder.CreateAlignedLoad(I32Ty, GEP, Align(4));
MDNode *MD = MDNode::get(Mod->getContext(), None);
Load->setMetadata(LLVMContext::MD_invariant_load, MD);
return Load; // throw std::runtime_error("not implemented on AMD");
}
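The ax + 3 in the GEP above comes straight from the packet layout quoted in the comment: header (2 bytes) + setup (2) + three uint16 workgroup sizes (6) + reserved0 (2) = 12 bytes, so grid_size_x/y/z sit at 32-bit word indices 3, 4 and 5. A small sketch that decodes a hypothetical packet the same way:

import struct

def grid_size(dispatch_packet: bytes, ax: int) -> int:
    assert ax < 3
    return struct.unpack_from("<I", dispatch_packet, offset=4 * (ax + 3))[0]

# hypothetical packet prefix: header=0, setup=0, workgroup=(64,1,1), reserved0=0, grid=(256,2,1)
pkt = struct.pack("<HHHHHH3I", 0, 0, 64, 1, 1, 0, 256, 2, 1)
print([grid_size(pkt, ax) for ax in range(3)])   # [256, 2, 1]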
Value* amd_cl_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
@@ -156,7 +212,7 @@ Value* cpu_target::get_block_id(Module *module, llvm::IRBuilder<> &builder, unsi
}
Value* cpu_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
throw std::runtime_error("not implemented");
throw std::runtime_error("not implemented on CPU");
}

View File

@@ -91,7 +91,7 @@ void inliner::do_inline(ir::function* fn, ir::call_inst* callsite, ir::builder&
if(inst_map.find(inst_op) != inst_map.end())
new_inst->set_operand(k, inst_map.at(inst_op));
}
// handles a ret instruction.
// handles a ret instruciton.
// instead of returning we need to branch to after the function call
if(ir::return_inst* ret = dynamic_cast<ir::return_inst*>(new_inst)) {
if(ir::value* ret_val = ret->get_return_value())

View File

@@ -222,6 +222,7 @@ bool dispatch::hipinit(){
return res;
}
#define HIP_DEFINE0(ret, fname) DEFINE0(hipinit, hip_, ret, fname)
#define HIP_DEFINE1(ret, fname, t1) DEFINE1(hipinit, hip_, ret, fname, t1)
#define HIP_DEFINE2(ret, fname, t1, t2) DEFINE2(hipinit, hip_, ret, fname, t1, t2)
#define HIP_DEFINE3(ret, fname, t1, t2, t3) DEFINE3(hipinit, hip_, ret, fname, t1, t2, t3)
@@ -278,7 +279,8 @@ HIP_DEFINE2(hipError_t, hipEventCreate, hipEvent_t *, unsigned int)
HIP_DEFINE3(hipError_t, hipEventElapsedTime, float *, hipEvent_t, hipEvent_t)
HIP_DEFINE2(hipError_t, hipEventRecord, hipEvent_t, hipStream_t)
HIP_DEFINE1(hipError_t, hipEventDestroy, hipEvent_t)
// error handling
HIP_DEFINE0(hipError_t, hipGetLastError)
/* ------------------- *
* COMMON

View File

@@ -25,6 +25,8 @@
#endif
#include <memory>
#include <regex>
#include <iomanip>
#include <string>
#include "triton/driver/llvm.h"
#include "triton/driver/dispatch.h"
#include "triton/driver/error.h"
@@ -57,6 +59,8 @@
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/Intrinsics.h"
// end AMD stuff
extern "C"
@@ -67,6 +71,24 @@ extern "C"
int setupterm(char *term, int fildes, int *errret) { return 0; }
}
namespace {
std::string gen_random(const int len) {
static const char alphanum[] =
"0123456789"
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz";
std::string tmp_s;
tmp_s.reserve(len);
for (int i = 0; i < len; ++i) {
tmp_s += alphanum[rand() % (sizeof(alphanum) - 1)];
}
return tmp_s;
}
}
namespace triton
{
namespace driver
@@ -266,20 +288,24 @@ namespace triton
/* ------------------------ */
// HIP //
/* ------------------------ */
std::string llir_to_amdgpu(llvm::Module *module, const std::string &_proc)
std::tuple<std::string, std::string> llir_to_amdgcn(llvm::Module *module, const std::string &_proc)
{
std::cout << "llvm.cc: llir_to_amdgcn:" << std::endl;
init_llvm();
// proc = std::get<0>(GetFeatureStrFromGCNArchName(rocminfo));
// features = std::get<1>(GetFeatureStrFromGCNArchName(rocminfo));
// proc = std::get<0>(GetFeatureStrFromGCNArchName(rocminfo));
// features = std::get<1>(GetFeatureStrFromGCNArchName(rocminfo));
// create
llvm::SmallVector<char, 0> buffer;
std::string triple = "amdgcn-amd-amdhsa";
std::string layout = "";
std::string features;
std::string proc = "gfx908";
std::string features = "+sramecc,-xnack";
std::string proc = _proc;
// name kernel
auto in_time_t = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
std::stringstream cur_time;
cur_time << std::put_time(std::localtime(&in_time_t), "%Y-%m-%d--%I-%M-%S");
std::string kernel_name = module->getModuleIdentifier() + "_" + cur_time.str() + "_" + gen_random(12);
// verify and store llvm
llvm::legacy::PassManager pm;
pm.add(llvm::createVerifierPass());
@@ -295,7 +321,7 @@ namespace triton
opt.NoNaNsFPMath = true;
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
llvm::Reloc::PIC_, llvm::None,
llvm::CodeGenOpt::Aggressive);
llvm::CodeGenOpt::None);
// set data layout
if (layout.empty())
module->setDataLayout(machine->createDataLayout());
@@ -308,11 +334,10 @@ namespace triton
llvm::raw_svector_ostream stream(buffer);
// create dump files
std::string module_name = module->getModuleIdentifier();
std::error_code ec;
// Save GCN ISA binary.
std::string isabin_path = std::string("/tmp/") + module_name + std::string(".o");
std::string isabin_path = std::string("/tmp/") + kernel_name + std::string(".o");
std::unique_ptr<llvm::raw_fd_ostream> isabin_fs(
new llvm::raw_fd_ostream(isabin_path, ec, llvm::sys::fs::OF_Text));
if (ec)
@@ -323,15 +348,17 @@ namespace triton
// emit
machine->addPassesToEmitFile(pass, *isabin_fs, nullptr, llvm::CGFT_ObjectFile);
pass.run(*module);
// Save GCN ISA.
std::string amdgcn_path = std::string("/tmp/") + module_name + std::string(".gcn");
std::string result(buffer.begin(), buffer.end());
std::ofstream amdgcn(amdgcn_path);
amdgcn << result;
amdgcn.close();
llvm::SmallVector<char, 0> debugBuffer;
llvm::legacy::PassManager debugPass;
llvm::raw_svector_ostream debugStream(debugBuffer);
machine->addPassesToEmitFile(debugPass, debugStream, nullptr, llvm::CodeGenFileType::CGFT_AssemblyFile); // TODO:cause segfault on REM ops also cause @llvm.amdgcn.if bug
debugPass.run(*module);
std::string amdgcn(debugBuffer.begin(), debugBuffer.end());
// generate HASCO file
std::string hsaco_path = std::string("/tmp/") + module_name + std::string(".hsaco");
std::string hsaco_path = std::string("/tmp/") + kernel_name + std::string(".hsaco");
std::string error_message;
int lld_result =
llvm::sys::ExecuteAndWait("/opt/rocm/llvm/bin/ld.lld",
@@ -344,13 +371,14 @@ namespace triton
std::cout << lld_result << std::endl;
}
return hsaco_path;
return std::make_tuple(amdgcn, hsaco_path);
}
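The kernel_name built above (module identifier + wall-clock timestamp + a 12-character tail from gen_random) is what keeps the /tmp/<name>.o and /tmp/<name>.hsaco artifacts from colliding when several kernels are compiled within the same second. A rough Python equivalent of that naming scheme (the helper name is illustrative; the character pool mirrors gen_random above):

import random
import string
from datetime import datetime

def unique_kernel_name(module_id: str, tail_len: int = 12) -> str:
    pool = string.digits + string.ascii_uppercase + string.ascii_lowercase
    stamp = datetime.now().strftime("%Y-%m-%d--%I-%M-%S")
    tail = "".join(random.choices(pool, k=tail_len))
    return f"{module_id}_{stamp}_{tail}"

print(unique_kernel_name("matmul_kernel"))   # e.g. matmul_kernel_<date>--<time>_<12 random chars>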
hipModule_t amdgpu_to_hipmodule(const std::string &path)
hipModule_t amdgpu_to_hipmodule(const std::string &hsaco_path)
{
std::cout << "llvm.cc: amdgpu_to_hipmodule:" << std::endl;
// Read HSACO.
std::ifstream hsaco_file(path, std::ios::binary | std::ios::ate);
std::ifstream hsaco_file(hsaco_path, std::ios::binary | std::ios::ate);
std::ifstream::pos_type hsaco_file_size = hsaco_file.tellg();
std::vector<unsigned char> hsaco(hsaco_file_size);

View File

@@ -60,6 +60,9 @@ value *builder::get_float16(float val)
value *builder::get_float32(float val)
{ return constant_fp::get(type::get_fp32_ty(ctx_), val); }
value *builder::get_float64(float val)
{ return constant_fp::get(type::get_fp64_ty(ctx_), val); }
value *builder::get_range(int32_t _lo, int32_t _hi) {
constant_int* lo = static_cast<constant_int*>(get_int32(_lo));
constant_int* hi = static_cast<constant_int*>(get_int32(_hi));

View File

@@ -13,6 +13,7 @@ from typing import NamedTuple
from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext
import torch
# Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py
@@ -32,7 +33,8 @@ def get_build_type():
def use_system_llvm():
if platform.system() == "Windows":
return True
versions = ['-11.0', '-11', '-11-64']
# versions = ['-11.0', '-11', '-11-64']
versions = ['-13.0', '-13', '-13-64']
supported = ['llvm-config{v}'.format(v=v) for v in versions]
paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
return any(p is not None for p in paths)
@@ -51,9 +53,9 @@ def get_thirdparty_packages(triton_cache_path):
Package("pybind11", "pybind11-2.10.0", "https://github.com/pybind/pybind11/archive/refs/tags/v2.10.0.tar.gz", "include/pybind11/pybind11.h", "PYBIND11_INCLUDE_DIR", "")
]
if not use_system_llvm():
# download LLVM if no suitable system LLVM is installed
# donwload LLVM if no suitable system LLVM is installed
packages.append(
Package("llvm", "clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04", "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04.tar.xz", "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR")
Package("llvm", "clang+llvm-13.0.0-x86_64-linux-gnu-ubuntu-16.04", "https://github.com/llvm/llvm-project/releases/download/llvmorg-13.0.0/clang+llvm-13.0.0-x86_64-linux-gnu-ubuntu-16.04.tar.xz", "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR")
)
thirdparty_cmake_args = []
@@ -144,6 +146,9 @@ class CMakeBuild(build_ext):
build_args += ["--", '-j' + str(2 * multiprocessing.cpu_count())]
env = os.environ.copy()
if torch.version.hip is not None:
env["TRITON_USE_ROCM"] = "ON"
subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=self.build_temp, env=env)
subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp)

View File

@@ -232,7 +232,7 @@ ir::value *store(ir::value *ptr, ir::value *val, std::optional<ir::value *> _mas
----------------------------------------------*/
std::string dot_docstr = R"pbdoc(
Returns the matrix product of two blocks.
The two blocks must be two dimensions and have compatible inner dimensions.
The two blocks must be two dimensionals and have compatible inner dimensions.
:param input: The first block to be multiplied.
:type input: 2D block of scalar-type in {`float16`, `float32`}
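A minimal usage sketch of the contract this docstring describes, assuming a working Triton install and a GPU: both operands are 2D blocks whose inner dimensions agree, and the result is an M x N block (the kernel name and block sizes below are illustrative):

import torch
import triton
import triton.language as tl

@triton.jit
def dot_kernel(A, B, C, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    rm = tl.arange(0, M)
    rn = tl.arange(0, N)
    rk = tl.arange(0, K)
    a = tl.load(A + rm[:, None] * K + rk[None, :])   # M x K block
    b = tl.load(B + rk[:, None] * N + rn[None, :])   # K x N block
    c = tl.dot(a, b)                                 # M x N block, fp32 accumulator
    tl.store(C + rm[:, None] * N + rn[None, :], c)

a = torch.randn((16, 16), device="cuda", dtype=torch.float16)
b = torch.randn((16, 16), device="cuda", dtype=torch.float16)
c = torch.empty((16, 16), device="cuda", dtype=torch.float32)
dot_kernel[(1,)](a, b, c, M=16, N=16, K=16)
assert torch.allclose(c, (a @ b).float(), atol=1e-2)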

View File

@@ -100,7 +100,9 @@ void hip_enqueue(uint64_t stream, uint64_t kernel,
drv::dispatch::hipModuleLaunchKernel((hipFunction_t)kernel, grid_0, grid_1, grid_2,
block_0, block_1, block_2,
shared_mem, (hipStream_t)stream, nullptr, config);
#ifdef DEBUG_ROCM
drv::dispatch::hipGetLastError();
#endif
}
long pow2_divisor(long N){
@@ -435,7 +437,7 @@ typedef std::map<std::string, py::object> asm_map_t;
// ---------------------------------------
void init_triton_codegen(py::module &&m) {
m.def("compile_ttir",
m.def("compile_ttir_to_ptx",
[](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs, size_t cc) {
std::ostringstream ttir;
int n_shared_bytes;
@@ -490,11 +492,64 @@ void init_triton_codegen(py::module &&m) {
},
py::return_value_policy::take_ownership);
m.def("compile_ttir_to_amdgcn",
[](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs, const std::string& gfx_arch) {
std::ostringstream ttir;
int n_shared_bytes;
std::string tmp;
std::string amdgcn;
std::string hsaco_path;
std::string name;
{
std::cout << "triton.cc: compile_ttir_to_amdgcn:" << std::endl;
// Scope where the GIL is released
py::gil_scoped_release allow_threads;
name = ir.get_function_list()[0]->get_name();
ir.print(ttir);
llvm::LLVMContext ctx;
// construct extern lib map
triton::codegen::ExternLibMap extern_lib_map;
for (auto item : extern_libs) {
auto name = item.first.cast<std::string>();
auto path = item.second.cast<std::string>();
extern_lib_map.emplace(
name, triton::codegen::create_extern_lib(name, path));
}
int version;
// std::string ptxas_path = drv::path_to_ptxas(version);
// Triton-IR -> AMDGCN LLVM-IR
std::cout << "ttir:" << std::endl;
std::cout << "\t" << ttir.str() << std::endl;
triton::codegen::amd_cl_target target;
auto llvm = triton::codegen::add_passes_to_emit_bin(
ir, ctx, &target, num_warps, num_stages, n_shared_bytes, extern_lib_map);
llvm::raw_string_ostream llir(tmp);
llir << *llvm;
std::cout << "llir:" << std::endl;
std::cout << "\t" << llir.str() << std::endl;
llir.flush();
// LLVM-IR -> AMDGPU
std::tuple<std::string, std::string> amdgpu = drv::llir_to_amdgcn(llvm.get(), gfx_arch);
amdgcn = std::get<0>(amdgpu);
hsaco_path = std::get<1>(amdgpu);
std::cout << "amdgcn:" << std::endl;
std::cout << "\t" << amdgcn << std::endl;
}
asm_map_t asm_map;
asm_map["ttir"] = py::cast(ttir.str());
asm_map["llir"] = py::cast(tmp);
asm_map["amdgcn"] = py::cast(amdgcn);
asm_map["hsaco_path"] = py::cast(hsaco_path);
return std::make_tuple(name, asm_map, n_shared_bytes);
},
py::return_value_policy::take_ownership);
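Assuming the Python wrapper surfaces this asm_map on the launched program object the same way the CUDA path surfaces 'ptx' (the tests below read pgm.asm['ptx']), the ROCm artifacts can be inspected after a launch; the kernel here is only a placeholder:

import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(dst, src, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    tl.store(dst + offs, tl.load(src + offs))

src = torch.arange(64, device="cuda", dtype=torch.float32)
dst = torch.empty_like(src)
pgm = copy_kernel[(1,)](dst, src, BLOCK=64)
# on the ROCm path the keys mirror the asm_map filled in above
print(sorted(pgm.asm.keys()))        # expected to include 'amdgcn', 'hsaco_path', 'llir', 'ttir'
print(pgm.asm["amdgcn"][:400])       # first lines of the GCN assembly dump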
// ---------------------------------------
// Load provided assembly code into driver
// ---------------------------------------
m.def("load_binary", [](const std::string& name, const std::string& data, size_t n_shared_bytes, uint64_t device){
m.def("load_binary_cubin", [](const std::string& name, const std::string& data, size_t n_shared_bytes, uint64_t device){
py::gil_scoped_release allow_threads;
// create driver handles
CUfunction fun;
@@ -521,6 +576,47 @@ void init_triton_codegen(py::module &&m) {
},
py::return_value_policy::take_ownership
);
// ---------------------------------------
// Load provided assembly code into driver
// ---------------------------------------
m.def("load_binary_hsaco", [](const std::string& name, const std::string& data, size_t n_shared_bytes, uint64_t device){
std::cout << "triton.cc: load_binary_hsaco:" << std::endl;
std::cout << "\tname:" << name << std::endl;
std::cout << "\tdata:" << data << std::endl;
std::cout << "\tn_shared_bytes:" << n_shared_bytes << std::endl;
std::cout << "\tdevice:" << device << std::endl;
py::gil_scoped_release allow_threads;
// create driver handles
std::cout << "\t" << "// create driver handles" << std::endl;
hipFunction_t fun;
hipModule_t mod = drv::amdgpu_to_hipmodule(data);
drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str());
// get allocated registers and spilled registers from the function
std::cout << "\t" << "// get allocated registers and spilled registers from the function" << std::endl;
int n_regs = 0;
int n_spills = 0;
hipFuncAttributes attr;
// drv::dispatch::hipFuncGetAttributes(&attr, fun);
// drv::dispatch::hipFuncGetAttributes(&attr, fun);
n_regs = attr.numRegs;
n_spills = attr.localSizeBytes / 4;
// set dynamic shared memory if necessary
std::cout << "\t" << "// set dynamic shared memory if necessary" << std::endl;
int shared_optin;
// drv::dispatch::hipDeviceGetAttribute(&shared_optin, hipDeviceAttributeSharedMemPerBlockOptin, device);
if(n_shared_bytes > 49152 && shared_optin > 49152){
// drv::dispatch::hipFuncSetCacheConfig(fun, hipFuncCachePreferShared);
int shared_total, shared_static;
// drv::dispatch::hipDeviceGetAttribute(&shared_total, hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, device);
// drv::dispatch::hipFuncGetAttributes(&attr, fun);
shared_total = attr.sharedSizeBytes;
// drv::dispatch::hipFuncSetAttribute(fun, hipFuncAttributeMaxDynamicSharedMemorySize, shared_optin - shared_static);
}
return std::make_tuple((uint64_t)mod, (uint64_t)fun, (uint64_t)n_regs, (uint64_t)n_spills);
},
py::return_value_policy::take_ownership
);
struct InstanceDescriptor

View File

@@ -128,7 +128,7 @@ elementwise_data = {
1024 * 16: 0.0219,
1024 * 64: 0.0791,
1024 * 256: 0.243,
1024 * 1024: 0.530,
1024 * 1024: 0.534,
1024 * 4096: 0.796,
1024 * 16384: 0.905,
1024 * 65536: 0.939,

View File

@@ -104,9 +104,12 @@ def check_type_supported(dtype):
'''
skip test if dtype is not supported on the current device
'''
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
if torch.version.hip is not None:
pass
else:
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes] + ["bfloat16"])
@@ -123,6 +126,9 @@ def test_empty_kernel(dtype_x, device='cuda'):
# generic test functions
def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
if torch.version.hip is not None:
if dtype_x == "bfloat16":
pytest.skip("unary op with bfloat is not supported on AMDGPU")
check_type_supported(dtype_x) # early return if dtype_x is not supported
SIZE = 128
# define the kernel / launch-grid
@@ -230,7 +236,6 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
('int64', 'float32'),
('int64', 'float64'),
('uint16', 'bfloat16'),
('uint16', 'float16'),
('uint16', 'float32'),
('uint32', 'bfloat16'),
('uint32', 'float16'),
@@ -253,8 +258,12 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
for dtype_y in dtypes_with_bfloat16
])
def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
if torch.version.hip is not None:
if dtype_x == "bfloat16" and dtype_y == "bfloat16" :
pytest.skip("binary op with bfloat is not supported on AMDGPU")
expr = f' x {op} y'
if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes:
if op == '%' and (dtype_x in dtypes and dtype_y in dtypes):
# LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders.
numpy_expr = 'np.fmod(x, y)'
elif op in ('/', '%') and dtype_x in ('int16', 'float16', 'bfloat16') and dtype_y in ('int16', 'float16', 'bfloat16'):
@@ -605,10 +614,10 @@ def test_tuples():
]
for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 70:
if dtype_x_str == 'float16':
pytest.skip("Only test atomic float16 ops on devices with sm >= 70")
if torch.version.hip is not None:
# if dtype_x_str in ["uint32","int32","float32"]:
pytest.skip(f"test_atomic_rmw[{dtype_x_str}] currently has segfaults on ROCM")
n_programs = 5
# triton kernel
@@ -675,6 +684,8 @@ def test_tensor_atomic_rmw(axis, device="cuda"):
def test_atomic_cas():
if torch.version.hip is not None:
pytest.skip(f"test_atomic_cas currently has segfaults on ROCM")
# 1. make sure that atomic_cas changes the original value (Lock)
@triton.jit
def change_value(Lock):
@@ -782,6 +793,8 @@ def test_store_bool():
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_f8_xf16_roundtrip(dtype):
if torch.version.hip is not None:
pytest.skip(f"test_f8_xf16_roundtrip currently has segfaults on ROCM")
"""Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
check_type_supported(dtype)
@@ -808,6 +821,8 @@ def test_f8_xf16_roundtrip(dtype):
def test_f16_to_f8_rounding():
if torch.version.hip is not None:
pytest.skip(f"test_atomic_cas currently has segfaults on ROCM")
"""Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
error is the minimum over all float8.
Or the same explanation a bit mathier:
@@ -876,6 +891,9 @@ def test_f16_to_f8_rounding():
for shape in [32, 64, 128, 512]])
def test_reduce1d(op, dtype_str, shape, device='cuda'):
check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested
if torch.version.hip is not None:
if dtype_str in ["int8", "int16", "uint8", "uint16"]:
pytest.skip(f"test_reduce1d[{dtype_str}] skipped on ROCM")
# triton kernel
@triton.jit
@@ -934,6 +952,10 @@ reduce_configs2 = [
@pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2)
def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested
if torch.version.hip is not None:
if dtype_str in ["int8", "int16", "uint8", "uint16"]:
pytest.skip(f"test_reduce2d[{dtype_str}] skipped on ROCM")
# triton kernel
@triton.jit
def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
@@ -1025,13 +1047,18 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
# compare
triton.testing.assert_almost_equal(z_tri, z_ref)
triton.testing.assert_almost_equal(z_tri_contiguous, z_ref)
# parse ptx to make sure ld/st are vectorized
ptx = pgm.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
ptx = pgm_contiguous.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
if torch.version.hip is None:
# parse ptx to make sure ld/st are vectorized
ptx = pgm.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
ptx = pgm_contiguous.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
else:
# TODO add rocm gcn assert
pass
# ---------------
# test dot
@@ -1045,14 +1072,15 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
for dtype in ['float16']
if not (allow_tf32 and (dtype in ['float16']))])
def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 70:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
if cc < 80:
if dtype == 'int8':
pytest.skip("Only test int8 on devices with sm >= 80")
elif dtype == 'float32' and allow_tf32:
pytest.skip("Only test tf32 on devices with sm >= 80")
if torch.version.hip is not None:
pass
else:
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 80:
if dtype == 'int8':
pytest.skip("Only test int8 on devices with sm >= 80")
elif dtype == 'float32' and allow_tf32:
pytest.skip("Only test tf32 on devices with sm >= 80")
M, N, K = 128, 128, 64
num_warps = 8
@@ -1147,15 +1175,18 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
# print(z_ref[:,0], z_tri[:,0])
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
# make sure ld/st are vectorized
ptx = pgm.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
if allow_tf32:
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
elif dtype == 'float32':
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
elif dtype == 'int8':
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
if torch.version.hip is not None:
pass
else:
ptx = pgm.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
if allow_tf32:
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
elif dtype == 'float32':
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
elif dtype == 'int8':
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
def test_dot_without_load():
@@ -1233,10 +1264,8 @@ def test_masked_load(dtype_str, size, size_diff, device='cuda'):
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_masked_load_shared_memory(dtype, device='cuda'):
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 70:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
if torch.version.hip is not None:
pytest.skip(f"test_masked_load_shared_memory currently has segfaults on ROCM")
check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested
M = 32
@@ -1296,16 +1325,20 @@ def test_load_cache_modifier(cache):
tl.store(dst + offsets, x)
pgm = _kernel[(1,)](dst, src, CACHE=cache)
ptx = pgm.asm['ptx']
if cache == '':
assert 'ld.global.ca' not in ptx
assert 'ld.global.cg' not in ptx
if cache == '.cg':
assert 'ld.global.cg' in ptx
assert 'ld.global.ca' not in ptx
if cache == '.ca':
assert 'ld.global.ca' in ptx
assert 'ld.global.cg' not in ptx
if torch.version.hip is None:
ptx = pgm.asm['ptx']
if cache == '':
assert 'ld.global.ca' not in ptx
assert 'ld.global.cg' not in ptx
if cache == '.cg':
assert 'ld.global.cg' in ptx
assert 'ld.global.ca' not in ptx
if cache == '.ca':
assert 'ld.global.ca' in ptx
assert 'ld.global.cg' not in ptx
else:
# TODO add rocm gcn assert
pass
@pytest.mark.parametrize("N", [16, 10, 11, 1024])
@@ -1319,11 +1352,15 @@ def test_vectorization(N):
x = tl.load(src + offsets, mask=offsets < N)
tl.store(dst + offsets, x, mask=offsets < N)
pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
ptx = pgm.asm["ptx"]
if N % 16 == 0:
assert "ld.global.v4.b32" in ptx
if torch.version.hip is None:
ptx = pgm.asm["ptx"]
if N % 16 == 0:
assert "ld.global.v4.b32" in ptx
else:
assert "ld.global.b32" in ptx
else:
assert "ld.global.b32" in ptx
#TODO add rocm assert
pass
# triton.testing.assert_almost_equal(dst, src[:N])
# ---------------
# test store
@@ -1557,7 +1594,8 @@ def test_num_warps_pow2():
('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
('float64', 'libdevice.norm4d', '')])
def test_libdevice_tensor(dtype_str, expr, lib_path):
if torch.version.hip is not None:
pytest.skip(f"test_libdevice_tensor currently has segfaults on ROCM")
@triton.jit
def kernel(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
@@ -1597,6 +1635,8 @@ def test_libdevice_tensor(dtype_str, expr, lib_path):
@pytest.mark.parametrize("dtype_str, expr, lib_path",
[('float32', 'libdevice.pow', '')])
def test_libdevice_scalar(dtype_str, expr, lib_path):
if torch.version.hip is not None:
pytest.skip(f"test_libdevice_scalar currently has segfaults on ROCM")
@triton.jit
def kernel(X, Y, BLOCK: tl.constexpr):

View File

@@ -2,7 +2,6 @@ import pytest
import torch
import triton
import triton._C.libtriton.triton as _triton
@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
@@ -126,10 +125,6 @@ def test_attention_fwd_bwd(
batch_size=2,
n_heads=2,
):
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 70:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
# inputs
qkv_shape = (batch_size, n_heads, n_ctx, 64)
qkvs = [

View File

@@ -68,8 +68,6 @@ import triton._C.libtriton.triton as _triton
)
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 70:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
if cc < 80 and DTYPE == "bfloat16":
pytest.skip("Only test bfloat16 on devices with sm >= 80")
if DTYPE == "bfloat16" and SPLIT_K != 1:

View File

@@ -7,6 +7,7 @@ import hashlib
import io
import json
import os
import re
import shutil
import subprocess
import sys
@@ -24,6 +25,12 @@ import triton
import triton._C.libtriton.triton as _triton
from .tools.disasm import extract
def static_vars(**kwargs):
def decorate(func):
for k in kwargs:
setattr(func, k, kwargs[k])
return func
return decorate
def str_to_ty(name):
if name[0] == "*":
@@ -880,31 +887,55 @@ def ptx_get_kernel_name(ptx: str) -> str:
if line.startswith('// .globl'):
return line.split()[-1]
@functools.lru_cache()
def rocm_path_dir():
return os.getenv("ROCM_PATH", default="/opt/rocm")
def _get_amdgpu_arch():
try:
rocminfo = subprocess.check_output(rocm_path_dir() + '/bin/rocminfo').decode()
gfx_arch = re.search('Name:\\s+.*(gfx\\d+)', rocminfo)
return gfx_arch.group(1).strip()
except:
return None
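_get_amdgpu_arch shells out to rocminfo and pulls the gfx token out of the agent Name: lines; here is the same regex run against a trimmed, illustrative sample of rocminfo output (the sample text is an assumption, not captured from a real machine):

import re

sample = """
  Marketing Name:          AMD Instinct MI100
  Name:                    gfx908
  Vendor Name:             AMD
"""
match = re.search(r'Name:\s+.*(gfx\d+)', sample)
print(match.group(1) if match else None)   # gfx908

If nothing matches (no AMD GPU agent, or rocminfo missing), the function returns None and _compile below raises unless MI_GPU_ARCH is set in the environment.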
@static_vars(discovered_gfx_arch = _get_amdgpu_arch())
def _compile(fn, signature: str, device: int = -1, constants=dict(),
specialization=_triton.code_gen.instance_descriptor(),
num_warps: int = 4, num_stages: int = 3, extern_libs=None,
output: str = "ttgir", cc=0) -> Tuple[str, int, str]:
print("compiler.py: _compile")
print(f"\t{fn, signature, device, constants, specialization, num_warps, num_stages, extern_libs, output, cc}")
valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
# assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
# triton-ir
module, _ = make_triton_ir(fn, signature, specialization, constants)
if output == "ttir":
return module
assert output == "cubin"
assert torch.version.hip is None
backend = _triton.runtime.backend.CUDA
assert (output == "cubin" or output == "hsaco")
if torch.version.hip is not None:
backend = _triton.runtime.backend.ROCM
else:
backend = _triton.runtime.backend.CUDA
if extern_libs is None:
extern_libs = dict()
name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, module, device, num_warps, num_stages, extern_libs, cc)
# compile ttir
if torch.version.hip is not None:
gfx_arch = os.environ.get('MI_GPU_ARCH', _compile.discovered_gfx_arch)
if gfx_arch is None:
raise RuntimeError('AMDGCN gfx arch is not defined.')
name, asm, shared_mem = _triton.code_gen.compile_ttir_to_amdgcn(backend, module, device, num_warps, num_stages, extern_libs, gfx_arch)
else:
name, asm, shared_mem = _triton.code_gen.compile_ttir_to_ptx(backend, module, device, num_warps, num_stages, extern_libs, cc)
return asm, shared_mem, name
def ty_to_cpp(ty):
if ty[0] == '*':
return "CUdeviceptr"
return "hipDeviceptr_t"
return {
"i1": "int32_t",
"i8": "int8_t",
@@ -967,17 +998,18 @@ def generate_launcher(identifier, constants, signature):
format = "iiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
# generate glue code
src = f"""
#include \"cuda.h\"
if torch.version.hip is not None:
src = f"""
#define __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#include <Python.h>
static inline void gpuAssert(CUresult code, const char *file, int line)
static inline void gpuAssert(hipError_t code, const char *file, int line)
{{
if (code != CUDA_SUCCESS)
if (code != HIP_SUCCESS)
{{
const char* prefix = "Triton Error [CUDA]: ";
const char* str;
cuGetErrorString(code, &str);
const char* str = hipGetErrorString(code);
char err[1024] = {{0}};
strcat(err, prefix);
strcat(err, str);
@@ -987,20 +1019,20 @@ static inline void gpuAssert(CUresult code, const char *file, int line)
#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, CUstream stream, CUfunction function, {arg_decls}) {{
void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, hipStream_t stream, hipFunction_t function, {arg_decls}) {{
void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
if(gridX*gridY*gridZ > 0){{
CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
hipModuleLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0);
}}
}}
static inline CUdeviceptr getPointer(PyObject *obj, int idx) {{
static inline hipDeviceptr_t getPointer(PyObject *obj, int idx) {{
if (PyLong_Check(obj)) {{
return (CUdeviceptr)PyLong_AsUnsignedLongLong(obj);
return (hipDeviceptr_t)PyLong_AsUnsignedLongLong(obj);
}}
if (obj == Py_None) {{
return (CUdeviceptr)0;
return (hipDeviceptr_t)0;
}}
PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
if(ptr){{
@@ -1011,14 +1043,15 @@ static inline CUdeviceptr getPointer(PyObject *obj, int idx) {{
if (!PyLong_Check(ret)) {{
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
}}
return (CUdeviceptr)PyLong_AsUnsignedLongLong(ret);
return (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
}}
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
return (CUdeviceptr)0;
return (hipDeviceptr_t)0;
}}
static PyObject* launch(PyObject* self, PyObject* args) {{
// printf("launch(PyObject* self, PyObject* args)");
int gridX, gridY, gridZ;
uint64_t _stream;
uint64_t _function;
@@ -1039,7 +1072,7 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
Py_DECREF(new_args);
}}
_launch(gridX, gridY, gridZ, num_warps, shared_memory, (CUstream)_stream, (CUfunction)_function, {', '.join(f"getPointer(_arg{i},{i})" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())});
_launch(gridX, gridY, gridZ, num_warps, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, {', '.join(f"getPointer(_arg{i},{i})" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())});
if (launch_exit_hook != Py_None) {{
PyObject *new_args = NULL;
@@ -1126,7 +1159,7 @@ class CacheManager:
os.rename(filepath + ".tmp", filepath)
# utilities for generating and compiling C wrappers
# utilties for generating and compiling C wrappers
@functools.lru_cache()
@@ -1134,6 +1167,10 @@ def libcuda_dirs():
locs = subprocess.check_output(["whereis", "libcuda.so"]).decode().strip().split()[1:]
return [os.path.dirname(loc) for loc in locs]
@functools.lru_cache()
def libhip_dirs():
return ["/opt/rocm/lib"]
@functools.lru_cache()
def cuda_home_dirs():
@@ -1141,6 +1178,15 @@ def cuda_home_dirs():
return os.getenv("CUDA_HOME", default=default_dir)
@functools.lru_cache()
def hip_home_dirs():
default_dir = "/opt/rocm"
return os.getenv("ROCM_HOME", default=default_dir)
@functools.lru_cache()
def rocm_path_dir():
return os.getenv("ROCM_PATH", default="/opt/rocm")
@contextlib.contextmanager
def quiet():
old_stdout, old_stderr = sys.stdout, sys.stderr
@@ -1152,8 +1198,15 @@ def quiet():
def _build(name, src, srcdir):
cuda_lib_dirs = libcuda_dirs()
cu_include_dir = os.path.join(cuda_home_dirs(), "include")
print("compiler.py: _build")
print(f"\t{name, src, srcdir}")
if torch.version.hip is not None:
hip_lib_dirs = libhip_dirs()
hip_include_dir = os.path.join(hip_home_dirs(), "include")
else:
cuda_lib_dirs = libcuda_dirs()
cu_include_dir = os.path.join(cuda_home_dirs(), "include")
suffix = sysconfig.get_config_var('EXT_SUFFIX')
so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
# try to avoid setuptools if possible
@@ -1164,16 +1217,29 @@ def _build(name, src, srcdir):
gcc = shutil.which("gcc")
cc = gcc if gcc is not None else clang
py_include_dir = get_paths()["include"]
cc_cmd = [cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-o", so]
cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs]
if torch.version.hip is not None:
cc_cmd = [cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lamdhip64", "-o", so]
cc_cmd += [f"-L{dir}" for dir in hip_lib_dirs]
else:
cc_cmd = [cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-o", so]
cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs]
print("\t", ''.join(cc_cmd))
ret = subprocess.check_call(cc_cmd)
if ret == 0:
print("ret:", ret)
print(so)
return so
# fallback on setuptools
extra_compile_args = []
library_dirs = cuda_lib_dirs
include_dirs = [srcdir, cu_include_dir]
libraries = ['cuda']
if torch.version.hip is not None:
library_dirs = hip_lib_dirs
include_dirs = [srcdir, hip_include_dir]
libraries = ['rocm']
else:
library_dirs = cuda_lib_dirs
include_dirs = [srcdir, cu_include_dir]
libraries = ['cuda']
# extra arguments
extra_link_args = []
# create extension module
@@ -1221,6 +1287,8 @@ def make_fn_cache_key(fn_hash, signature, configs, constants, num_warps, num_sta
def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4,
num_stages: int = 3, extern_libs=None, configs=None, cc=0, warm_cache_only=False):
print("compiler.py: compile")
print(f"\t{fn, signature, device, constants, num_warps, num_stages, extern_libs, configs, cc, warm_cache_only}")
# we get the kernel, i.e. the first function generated in the module
assert len(configs) == 1
# cache manager
@@ -1236,30 +1304,52 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
src_path = os.path.join(tmpdir, "main.c")
with open(src_path, "w") as f:
f.write(src)
so = _build(fn.__name__, src_path, tmpdir)
so = _build(fn.__name__, src_path, tmpdir) # build step
with open(so, "rb") as f:
so_cache_manager.put(f.read(), so_name, binary=True)
# retrieve cached shared object if it exists
fn_cache_key = make_fn_cache_key(fn.cache_key, signature, configs, constants, num_warps, num_stages)
fn_cache_manager = CacheManager(fn_cache_key)
ptx_name = f"{name}.ptx"
cubin_name = f"{name}.cubin"
if torch.version.hip is not None:
amdgcn_name = f"{name}.gcn"
hsaco_name = f"{name}.hsaco"
assembly_name = amdgcn_name
binary_name = hsaco_name
else:
ptx_name = f"{name}.ptx"
cubin_name = f"{name}.cubin"
assembly_name = ptx_name
binary_name = cubin_name
data_name = f"{name}.json"
ttir_name = f"{name}.ttir"
llir_name = f"{name}.llir"
if not fn_cache_manager.has_file(cubin_name) or \
if not fn_cache_manager.has_file(binary_name) or \
not fn_cache_manager.has_file(data_name) or \
not fn_cache_manager.has_file(ptx_name) or \
not fn_cache_manager.has_file(assembly_name) or \
not fn_cache_manager.has_file(ttir_name) or \
not fn_cache_manager.has_file(llir_name):
asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
extern_libs, "cubin", cc)
metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
fn_cache_manager.put(asm["cubin"], cubin_name)
fn_cache_manager.put(asm["ptx"], ptx_name, binary=False)
if torch.version.hip is not None:
asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
extern_libs, "hsaco", cc)
# cache AMD assembly and binary
fn_cache_manager.put(asm["hsaco_path"], binary_name, binary=False)
fn_cache_manager.put(asm["amdgcn"], assembly_name, binary=False)
else:
asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
extern_libs, "cubin", cc)
# cache Nvidia assembly and binary
fn_cache_manager.put(asm["cubin"], binary_name)
fn_cache_manager.put(asm["ptx"], assembly_name, binary=False)
# cache triton and llvm ir
fn_cache_manager.put(asm["ttir"], ttir_name, binary=False)
fn_cache_manager.put(asm["llir"], llir_name, binary=False)
# cache metadata
metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
fn_cache_manager.put(json.dumps(metadata), data_name, binary=False)
if warm_cache_only:
@@ -1275,9 +1365,12 @@ class CompiledKernel:
launch_exit_hook = None
def __init__(self, fn_name, so_path, cache_dir, device):
print("compiler.py: CompiledKernel:__init__")
print(f"\t{self, fn_name, so_path, cache_dir, device}")
# initialize launcher
import importlib.util
spec = importlib.util.spec_from_file_location("launcher", so_path)
print("spec:", spec)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
self.c_wrapper = getattr(mod, "launch")
@@ -1289,16 +1382,25 @@ class CompiledKernel:
self.num_stages = metadata["num_stages"]
# initialize asm dict
self.asm = dict()
with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
self.asm["cubin"] = f.read()
with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
self.asm["ptx"] = f.read()
if torch.version.hip is not None:
with open(os.path.join(cache_dir, f"{fn_name}.hsaco"), "rb") as f:
self.asm["hsaco_path"] = f.read()
with open(os.path.join(cache_dir, f"{fn_name}.gcn"), "r") as f:
self.asm["amdgcn"] = f.read()
else:
with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
self.asm["cubin"] = f.read()
with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
self.asm["ptx"] = f.read()
with open(os.path.join(cache_dir, f"{fn_name}.llir"), "r") as f:
self.asm["llir"] = f.read()
with open(os.path.join(cache_dir, f"{fn_name}.ttir"), "r") as f:
self.asm["ttir"] = f.read()
mod, func, n_regs, n_spills = _triton.code_gen.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
if torch.version.hip is not None:
mod, func, n_regs, n_spills = _triton.code_gen.load_binary_hsaco(metadata["name"], self.asm["hsaco_path"], self.shared, device)
else:
mod, func, n_regs, n_spills = _triton.code_gen.load_binary_cubin(metadata["name"], self.asm["cubin"], self.shared, device)
self.fn_name = fn_name
self.cu_module = mod
self.cu_function = func

View File

@@ -768,7 +768,7 @@ def dot(input, other, trans_a=False, trans_b=False, allow_tf32=True, _builder=No
"""
Returns the matrix product of two blocks.
The two blocks must be two dimensions and have compatible inner dimensions.
The two blocks must be two-dimensional and have compatible inner dimensions.
:param input: The first tensor to be multiplied.
:type input: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
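As a quick illustration of the tl.dot contract described in the docstring above, here is a minimal sketch (not part of the diff) that multiplies two square tiles inside a @triton.jit kernel; dot_kernel and BLOCK are illustrative names, and the example assumes a CUDA device with compute capability >= 70.

import torch
import triton
import triton.language as tl

@triton.jit
def dot_kernel(A, B, C, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # load one BLOCK x BLOCK row-major tile from each input
    a = tl.load(A + offs[:, None] * BLOCK + offs[None, :])
    b = tl.load(B + offs[:, None] * BLOCK + offs[None, :])
    c = tl.dot(a, b)  # inner dimensions must match; fp16 inputs accumulate in fp32
    tl.store(C + offs[:, None] * BLOCK + offs[None, :], c)

a = torch.randn((32, 32), device='cuda', dtype=torch.float16)
b = torch.randn((32, 32), device='cuda', dtype=torch.float16)
c = torch.empty((32, 32), device='cuda', dtype=torch.float32)
dot_kernel[(1,)](a, b, c, BLOCK=32)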

View File

@@ -58,7 +58,13 @@ def mulhi(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.int32, core.int32,): ("__nv_mulhi", core.int32),
(core.uint32, core.uint32,): ("__nv_umulhi", core.uint32),
(core.int64, core.int64,): ("__nv_mul64hi", core.int64),
}, _builder)
@extern.extern
def mul64hi(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.int64, core.int64,): ("__nv_mul64hi", core.int64),
(core.uint64, core.uint64,): ("__nv_umul64hi", core.uint64),
}, _builder)
@@ -152,137 +158,261 @@ def saturatef(arg0, _builder=None):
@extern.extern
def fma_rn(arg0, arg1, arg2, _builder=None):
def fmaf_rn(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_rn", core.float32),
(core.float64, core.float64, core.float64,): ("__nv_fma_rn", core.float64),
}, _builder)
@extern.extern
def fmaf_rz(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_rz", core.float32),
}, _builder)
@extern.extern
def fmaf_rd(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_rd", core.float32),
}, _builder)
@extern.extern
def fmaf_ru(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_ru", core.float32),
}, _builder)
@extern.extern
def fmaf_ieee_rn(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_ieee_rn", core.float32),
}, _builder)
@extern.extern
def fmaf_ieee_rz(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_ieee_rz", core.float32),
}, _builder)
@extern.extern
def fmaf_ieee_rd(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_ieee_rd", core.float32),
}, _builder)
@extern.extern
def fmaf_ieee_ru(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_ieee_ru", core.float32),
}, _builder)
@extern.extern
def fma_rn(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float64, core.float64, core.float64,): ("__nv_fma_rn", core.float64),
}, _builder)
@extern.extern
def fma_rz(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_rz", core.float32),
(core.float64, core.float64, core.float64,): ("__nv_fma_rz", core.float64),
{(core.float64, core.float64, core.float64,): ("__nv_fma_rz", core.float64),
}, _builder)
@extern.extern
def fma_rd(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_rd", core.float32),
(core.float64, core.float64, core.float64,): ("__nv_fma_rd", core.float64),
{(core.float64, core.float64, core.float64,): ("__nv_fma_rd", core.float64),
}, _builder)
@extern.extern
def fma_ru(arg0, arg1, arg2, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
{(core.float32, core.float32, core.float32,): ("__nv_fmaf_ru", core.float32),
(core.float64, core.float64, core.float64,): ("__nv_fma_ru", core.float64),
{(core.float64, core.float64, core.float64,): ("__nv_fma_ru", core.float64),
}, _builder)
@extern.extern
def fast_dividef(arg0, arg1, _builder=None):
def fast_fdividef(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fast_fdividef", core.float32),
}, _builder)
@extern.extern
def div_rn(arg0, arg1, _builder=None):
def fdiv_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fdiv_rn", core.float32),
(core.float64, core.float64,): ("__nv_ddiv_rn", core.float64),
}, _builder)
@extern.extern
def div_rz(arg0, arg1, _builder=None):
def fdiv_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fdiv_rz", core.float32),
(core.float64, core.float64,): ("__nv_ddiv_rz", core.float64),
}, _builder)
@extern.extern
def div_rd(arg0, arg1, _builder=None):
def fdiv_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fdiv_rd", core.float32),
(core.float64, core.float64,): ("__nv_ddiv_rd", core.float64),
}, _builder)
@extern.extern
def div_ru(arg0, arg1, _builder=None):
def fdiv_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fdiv_ru", core.float32),
(core.float64, core.float64,): ("__nv_ddiv_ru", core.float64),
}, _builder)
@extern.extern
def rcp_rn(arg0, _builder=None):
def frcp_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_frcp_rn", core.float32),
(core.float64,): ("__nv_drcp_rn", core.float64),
}, _builder)
@extern.extern
def rcp_rz(arg0, _builder=None):
def frcp_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_frcp_rz", core.float32),
(core.float64,): ("__nv_drcp_rz", core.float64),
}, _builder)
@extern.extern
def rcp_rd(arg0, _builder=None):
def frcp_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_frcp_rd", core.float32),
(core.float64,): ("__nv_drcp_rd", core.float64),
}, _builder)
@extern.extern
def rcp_ru(arg0, _builder=None):
def frcp_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_frcp_ru", core.float32),
(core.float64,): ("__nv_drcp_ru", core.float64),
}, _builder)
@extern.extern
def sqrt_rn(arg0, _builder=None):
def fsqrt_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_fsqrt_rn", core.float32),
(core.float64,): ("__nv_dsqrt_rn", core.float64),
}, _builder)
@extern.extern
def sqrt_rz(arg0, _builder=None):
def fsqrt_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_fsqrt_rz", core.float32),
(core.float64,): ("__nv_dsqrt_rz", core.float64),
}, _builder)
@extern.extern
def sqrt_rd(arg0, _builder=None):
def fsqrt_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_fsqrt_rd", core.float32),
(core.float64,): ("__nv_dsqrt_rd", core.float64),
}, _builder)
@extern.extern
def sqrt_ru(arg0, _builder=None):
def fsqrt_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_fsqrt_ru", core.float32),
(core.float64,): ("__nv_dsqrt_ru", core.float64),
}, _builder)
@extern.extern
def ddiv_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_ddiv_rn", core.float64),
}, _builder)
@extern.extern
def ddiv_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_ddiv_rz", core.float64),
}, _builder)
@extern.extern
def ddiv_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_ddiv_rd", core.float64),
}, _builder)
@extern.extern
def ddiv_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_ddiv_ru", core.float64),
}, _builder)
@extern.extern
def drcp_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_drcp_rn", core.float64),
}, _builder)
@extern.extern
def drcp_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_drcp_rz", core.float64),
}, _builder)
@extern.extern
def drcp_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_drcp_rd", core.float64),
}, _builder)
@extern.extern
def drcp_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_drcp_ru", core.float64),
}, _builder)
@extern.extern
def dsqrt_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_dsqrt_rn", core.float64),
}, _builder)
@extern.extern
def dsqrt_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_dsqrt_rz", core.float64),
}, _builder)
@extern.extern
def dsqrt_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_dsqrt_rd", core.float64),
}, _builder)
@extern.extern
def dsqrt_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_dsqrt_ru", core.float64),
}, _builder)
@@ -295,66 +425,114 @@ def sqrt(arg0, _builder=None):
@extern.extern
def add_rn(arg0, arg1, _builder=None):
def dadd_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dadd_rn", core.float64),
(core.float32, core.float32,): ("__nv_fadd_rn", core.float32),
}, _builder)
@extern.extern
def add_rz(arg0, arg1, _builder=None):
def dadd_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dadd_rz", core.float64),
(core.float32, core.float32,): ("__nv_fadd_rz", core.float32),
}, _builder)
@extern.extern
def add_rd(arg0, arg1, _builder=None):
def dadd_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dadd_rd", core.float64),
(core.float32, core.float32,): ("__nv_fadd_rd", core.float32),
}, _builder)
@extern.extern
def add_ru(arg0, arg1, _builder=None):
def dadd_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dadd_ru", core.float64),
(core.float32, core.float32,): ("__nv_fadd_ru", core.float32),
}, _builder)
@extern.extern
def mul_rn(arg0, arg1, _builder=None):
def dmul_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dmul_rn", core.float64),
(core.float32, core.float32,): ("__nv_fmul_rn", core.float32),
}, _builder)
@extern.extern
def mul_rz(arg0, arg1, _builder=None):
def dmul_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dmul_rz", core.float64),
(core.float32, core.float32,): ("__nv_fmul_rz", core.float32),
}, _builder)
@extern.extern
def mul_rd(arg0, arg1, _builder=None):
def dmul_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dmul_rd", core.float64),
(core.float32, core.float32,): ("__nv_fmul_rd", core.float32),
}, _builder)
@extern.extern
def mul_ru(arg0, arg1, _builder=None):
def dmul_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dmul_ru", core.float64),
(core.float32, core.float32,): ("__nv_fmul_ru", core.float32),
}, _builder)
@extern.extern
def fadd_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fadd_rd", core.float32),
}, _builder)
@extern.extern
def fadd_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fadd_ru", core.float32),
}, _builder)
@extern.extern
def fmul_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fmul_rd", core.float32),
}, _builder)
@extern.extern
def fmul_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fmul_ru", core.float32),
}, _builder)
@extern.extern
def fadd_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fadd_rn", core.float32),
}, _builder)
@extern.extern
def fadd_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fadd_rz", core.float32),
}, _builder)
@extern.extern
def fmul_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fmul_rn", core.float32),
}, _builder)
@extern.extern
def fmul_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fmul_rz", core.float32),
}, _builder)
@@ -446,13 +624,7 @@ def double2uint_ru(arg0, _builder=None):
def int2double_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int32,): ("__nv_int2double_rn", core.float64),
}, _builder)
@extern.extern
def uint2double_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint32,): ("__nv_uint2double_rn", core.float64),
(core.uint32,): ("__nv_uint2double_rn", core.float64),
}, _builder)
@@ -516,6 +688,7 @@ def float2uint_ru(arg0, _builder=None):
def int2float_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int32,): ("__nv_int2float_rn", core.float32),
(core.uint32,): ("__nv_uint2float_rn", core.float32),
}, _builder)
@@ -523,6 +696,7 @@ def int2float_rn(arg0, _builder=None):
def int2float_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int32,): ("__nv_int2float_rz", core.float32),
(core.uint32,): ("__nv_uint2float_rz", core.float32),
}, _builder)
@@ -530,6 +704,7 @@ def int2float_rz(arg0, _builder=None):
def int2float_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int32,): ("__nv_int2float_rd", core.float32),
(core.uint32,): ("__nv_uint2float_rd", core.float32),
}, _builder)
@@ -537,34 +712,7 @@ def int2float_rd(arg0, _builder=None):
def int2float_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int32,): ("__nv_int2float_ru", core.float32),
}, _builder)
@extern.extern
def uint2float_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint32,): ("__nv_uint2float_rn", core.float32),
}, _builder)
@extern.extern
def uint2float_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint32,): ("__nv_uint2float_rz", core.float32),
}, _builder)
@extern.extern
def uint2float_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint32,): ("__nv_uint2float_rd", core.float32),
}, _builder)
@extern.extern
def uint2float_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint32,): ("__nv_uint2float_ru", core.float32),
(core.uint32,): ("__nv_uint2float_ru", core.float32),
}, _builder)
@@ -705,6 +853,7 @@ def double2ull_ru(arg0, _builder=None):
def ll2float_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2float_rn", core.float32),
(core.uint64,): ("__nv_ull2float_rn", core.float32),
}, _builder)
@@ -712,6 +861,7 @@ def ll2float_rn(arg0, _builder=None):
def ll2float_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2float_rz", core.float32),
(core.uint64,): ("__nv_ull2float_rz", core.float32),
}, _builder)
@@ -719,6 +869,7 @@ def ll2float_rz(arg0, _builder=None):
def ll2float_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2float_rd", core.float32),
(core.uint64,): ("__nv_ull2float_rd", core.float32),
}, _builder)
@@ -726,34 +877,7 @@ def ll2float_rd(arg0, _builder=None):
def ll2float_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2float_ru", core.float32),
}, _builder)
@extern.extern
def ull2float_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2float_rn", core.float32),
}, _builder)
@extern.extern
def ull2float_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2float_rz", core.float32),
}, _builder)
@extern.extern
def ull2float_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2float_rd", core.float32),
}, _builder)
@extern.extern
def ull2float_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2float_ru", core.float32),
(core.uint64,): ("__nv_ull2float_ru", core.float32),
}, _builder)
@@ -761,6 +885,7 @@ def ull2float_ru(arg0, _builder=None):
def ll2double_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2double_rn", core.float64),
(core.uint64,): ("__nv_ull2double_rn", core.float64),
}, _builder)
@@ -768,6 +893,7 @@ def ll2double_rn(arg0, _builder=None):
def ll2double_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2double_rz", core.float64),
(core.uint64,): ("__nv_ull2double_rz", core.float64),
}, _builder)
@@ -775,6 +901,7 @@ def ll2double_rz(arg0, _builder=None):
def ll2double_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2double_rd", core.float64),
(core.uint64,): ("__nv_ull2double_rd", core.float64),
}, _builder)
@@ -782,34 +909,7 @@ def ll2double_rd(arg0, _builder=None):
def ll2double_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int64,): ("__nv_ll2double_ru", core.float64),
}, _builder)
@extern.extern
def ull2double_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2double_rn", core.float64),
}, _builder)
@extern.extern
def ull2double_rz(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2double_rz", core.float64),
}, _builder)
@extern.extern
def ull2double_rd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2double_rd", core.float64),
}, _builder)
@extern.extern
def ull2double_ru(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint64,): ("__nv_ull2double_ru", core.float64),
(core.uint64,): ("__nv_ull2double_ru", core.float64),
}, _builder)
@@ -817,6 +917,7 @@ def ull2double_ru(arg0, _builder=None):
def int_as_float(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.int32,): ("__nv_int_as_float", core.float32),
(core.uint32,): ("__nv_uint_as_float", core.float32),
}, _builder)
@@ -827,13 +928,6 @@ def float_as_int(arg0, _builder=None):
}, _builder)
@extern.extern
def uint_as_float(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.uint32,): ("__nv_uint_as_float", core.float32),
}, _builder)
@extern.extern
def float_as_uint(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
@@ -912,9 +1006,11 @@ def fast_log10f(arg0, _builder=None):
@extern.extern
def fast_powf(arg0, arg1, _builder=None):
def pow(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fast_powf", core.float32),
(core.float32, core.float32,): ("__nv_powf", core.float32),
(core.float64, core.float64,): ("__nv_pow", core.float64),
}, _builder)
@@ -935,39 +1031,35 @@ def rhadd(arg0, arg1, _builder=None):
@extern.extern
def sub_rn(arg0, arg1, _builder=None):
def fsub_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fsub_rn", core.float32),
(core.float64, core.float64,): ("__nv_dsub_rn", core.float64),
}, _builder)
@extern.extern
def sub_rz(arg0, arg1, _builder=None):
def fsub_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fsub_rz", core.float32),
(core.float64, core.float64,): ("__nv_dsub_rz", core.float64),
}, _builder)
@extern.extern
def sub_rd(arg0, arg1, _builder=None):
def fsub_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fsub_rd", core.float32),
(core.float64, core.float64,): ("__nv_dsub_rd", core.float64),
}, _builder)
@extern.extern
def sub_ru(arg0, arg1, _builder=None):
def fsub_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.float32,): ("__nv_fsub_ru", core.float32),
(core.float64, core.float64,): ("__nv_dsub_ru", core.float64),
}, _builder)
@extern.extern
def rsqrt_rn(arg0, _builder=None):
def frsqrt_rn(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_frsqrt_rn", core.float32),
}, _builder)
@@ -1006,18 +1098,16 @@ def nearbyint(arg0, _builder=None):
@extern.extern
def isnan(arg0, _builder=None):
def isnanf(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_isnanf", core.int32),
(core.float64,): ("__nv_isnand", core.int32),
}, _builder)
@extern.extern
def signbit(arg0, _builder=None):
def signbitf(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_signbitf", core.int32),
(core.float64,): ("__nv_signbitd", core.int32),
}, _builder)
@@ -1037,10 +1127,9 @@ def finitef(arg0, _builder=None):
@extern.extern
def isinf(arg0, _builder=None):
def isinff(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float32,): ("__nv_isinff", core.int32),
(core.float64,): ("__nv_isinfd", core.int32),
}, _builder)
@@ -1461,12 +1550,10 @@ def fma(arg0, arg1, arg2, _builder=None):
@extern.extern
def pow(arg0, arg1, _builder=None):
def powi(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float32, core.int32,): ("__nv_powif", core.float32),
(core.float64, core.int32,): ("__nv_powi", core.float64),
(core.float32, core.float32,): ("__nv_powf", core.float32),
(core.float64, core.float64,): ("__nv_pow", core.float64),
}, _builder)
@@ -1518,8 +1605,57 @@ def logb(arg0, _builder=None):
}, _builder)
@extern.extern
def signbitd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_signbitd", core.int32),
}, _builder)
@extern.extern
def isfinited(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_isfinited", core.int32),
}, _builder)
@extern.extern
def isinfd(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_isinfd", core.int32),
}, _builder)
@extern.extern
def isnand(arg0, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ],
{(core.float64,): ("__nv_isnand", core.int32),
}, _builder)
@extern.extern
def dsub_rn(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dsub_rn", core.float64),
}, _builder)
@extern.extern
def dsub_rz(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dsub_rz", core.float64),
}, _builder)
@extern.extern
def dsub_ru(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dsub_ru", core.float64),
}, _builder)
@extern.extern
def dsub_rd(arg0, arg1, _builder=None):
return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ],
{(core.float64, core.float64,): ("__nv_dsub_rd", core.float64),
}, _builder)

View File

@@ -6,6 +6,8 @@ from . import core as tl
from triton._C.libtriton.triton import ir
import torch
# Custom exception raised when two operand types cannot be combined
class IncompatibleTypeErrorimpl(Exception):
def __init__(self, type_a, type_b):
@@ -969,6 +971,11 @@ def dot(a: tl.tensor,
trans_b: bool,
allow_tf32: bool,
builder: ir.builder) -> tl.tensor:
if torch.version.hip is not None:
a = cast(a, tl.float32, builder)
b = cast(b, tl.float32, builder)
in_a = 1 if not trans_a else 0
in_b = 1 if trans_b else 0
assert a.type.is_block() and b.type.is_block()

View File

@@ -270,7 +270,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
# build stub signature -- includes arguments that are specialized
for i, arg in constants.items():
if callable(arg):
raise TypeError(f"Callable constexpr at index {{i}} is not supported")
raise TypeError(f"Callable constexpr at index {i} is not supported")
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs)
if not warmup:

View File

@@ -115,9 +115,7 @@ def nvsmi(attrs):
return ret
def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
percentiles=(0.5, 0.2, 0.8),
record_clocks=False, fast_flush=False):
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0.8], record_clocks=False):
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.
@@ -132,8 +130,6 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
:type grad_to_none: torch.tensor, optional
:param percentiles: Performance percentile to return in addition to the median.
:type percentiles: list[float]
:param fast_flush: Use faster kernel to flush L2 between measurements
:type fast_flush: bool
"""
# Estimate the runtime of the function
@@ -155,10 +151,7 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
# doesn't contain any input data before the run
start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
if fast_flush:
cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
else:
cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
# Warm-up
for _ in range(n_warmup):
fn()
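For reference, a minimal sketch (not part of the diff) of how do_bench is typically called with the default percentiles documented above; the tensor shapes are illustrative.

import torch
import triton.testing

a = torch.randn((1024, 1024), device='cuda', dtype=torch.float16)
b = torch.randn((1024, 1024), device='cuda', dtype=torch.float16)

# returns (median, 20th percentile, 80th percentile) runtimes in milliseconds
ms, p20, p80 = triton.testing.do_bench(lambda: torch.matmul(a, b))
print(f"median {ms:.3f} ms (p20={p20:.3f}, p80={p80:.3f})")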

View File

@@ -1,24 +1,10 @@
import argparse
import subprocess
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
class Symbol:
_name: str
_op_name: str
_ret_type: str
_arg_names: List[str]
_arg_types: List[str]
def __init__(
self,
name: str,
op_name: str,
ret_type: str,
arg_names: List[str],
arg_types: List[str],
) -> None:
def __init__(self, name: str, op_name: str, ret_type: str, arg_names: list, arg_types: list) -> None:
'''
A symbol is a function declaration.
@@ -31,31 +17,31 @@ class Symbol:
self._name = name
self._op_name = op_name
self._ret_type = ret_type
self._arg_names = list(arg_names)
self._arg_types = list(arg_types)
self._arg_names = arg_names
self._arg_types = arg_types
@property
def name(self) -> str:
def name(self):
return self._name
@property
def op_name(self) -> str:
def op_name(self):
return self._op_name
@property
def ret_type(self) -> str:
def ret_type(self):
return self._ret_type
@property
def arg_names(self) -> List[str]:
def arg_names(self):
return self._arg_names
@property
def arg_types(self) -> List[str]:
def arg_types(self):
return self._arg_types
def convert_type(type_str) -> Optional[str]:
def convert_type(type_str):
if type_str == "i32":
return "int32"
elif type_str == "u32":
@@ -73,7 +59,7 @@ def convert_type(type_str) -> Optional[str]:
return None
def to_unsigned(type_str) -> str:
def to_unsigned(type_str):
if type_str == "int32":
return "uint32"
elif type_str == "int64":
@@ -83,19 +69,7 @@ def to_unsigned(type_str) -> str:
class ExternLibrary(ABC):
_name: str
_path: str
_symbols: Dict[str, Symbol]
_format: bool
_grouping: bool
def __init__(
self,
name: str,
path: str,
format: bool = True,
grouping: bool = True,
) -> None:
def __init__(self, name: str, path: str, format: bool = True, grouping: bool = True) -> None:
'''
Abstract class for extern library.
@@ -106,34 +80,34 @@ class ExternLibrary(ABC):
self._name = name
self._path = path
self._symbols = {}
self._format = format
self._format = True
self._grouping = grouping
@property
def name(self) -> str:
def name(self):
return self._name
@property
def path(self) -> str:
def path(self):
return self._path
@property
def symbols(self) -> Dict[str, Symbol]:
def symbols(self):
return self._symbols
@property
def grouping(self) -> bool:
def grouping(self):
return self._grouping
@abstractmethod
def parse_symbols(self, input_file) -> None:
def parse_symbols(self, input_file):
pass
@abstractmethod
def _output_stubs(self) -> str:
pass
def generate_stub_file(self, output_dir) -> None:
def generate_stub_file(self, output_dir):
file_str = self._output_stubs()
if file_str is None or len(file_str) == 0:
raise Exception("file_str is empty")
@@ -149,8 +123,6 @@ class ExternLibrary(ABC):
class Libdevice(ExternLibrary):
_symbol_groups: Dict[str, List[Symbol]]
def __init__(self, path) -> None:
'''
Constructor for Libdevice.
@@ -160,7 +132,7 @@ class Libdevice(ExternLibrary):
super().__init__("libdevice", path)
self._symbol_groups = {}
def _extract_symbol(self, line) -> Optional[Symbol]:
def _extract_symbol(self, line):
# Extract symbols from line in the following format:
# "define [internal] <ret_type> @<name>(<arg_types>,)"
entries = line.split("@")
@@ -177,9 +149,6 @@ class Libdevice(ExternLibrary):
func_strs = func_str.split("(")
func_name = func_strs[0].replace("@", "")
op_name = func_name.replace("__nv_", "")
# To filter some interfaces unlisted in NVIDIA's official documents.
if 'ieee' in op_name:
return None
# Get arg_types
arg_strs = func_strs[1].split(",")
arg_types = []
@@ -202,77 +171,66 @@ class Libdevice(ExternLibrary):
arg_types[i] = to_unsigned(arg_type)
return Symbol(func_name, op_name, ret_type, arg_names, arg_types)
def _group_symbols(self) -> None:
def _group_symbols(self):
symbol_set = {}
for symbol in self._symbols.values():
op_name = symbol.op_name
symbol_set[op_name] = symbol
# Group functions together by renaming.
renaming = {
'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh',
'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn': 'add_rn',
'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru',
'dadd_rz': 'add_rz', 'fadd_rz': 'add_rz', 'asinf': 'asin',
'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2',
'atanhf': 'atanh', 'brevll': 'brev', 'cbrtf': 'cbrt',
'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign',
'cosf': 'cos', 'coshf': 'cosh', 'cospif': 'cospi',
'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1',
'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn',
'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru', 'ddiv_ru': 'div_ru',
'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf',
'erfcf': 'erfc', 'erfcinvf': 'erfcinv', 'erfcxf': 'erfcx',
'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10',
'exp2f': 'exp2', 'expm1f': 'expm1', 'fabsf': 'abs',
'fabs': 'abs', 'fast_fdividef': 'fast_dividef',
'fdimf': 'fdim', 'ffsll': 'ffs', 'floorf': 'floor',
'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn',
'fmaf_ru': 'fma_ru', 'fmaf_rz': 'fma_rz', 'fmodf': 'fmod',
'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb',
'isinff': 'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan',
'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn',
'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint',
'llroundf': 'llround', 'logf': 'log', 'log10f': 'log10',
'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb',
'umax': 'max', 'llmax': 'max', 'ullmax': 'max', 'fmaxf': 'max',
'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min',
'fminf': 'min', 'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd',
'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn', 'dmul_ru': 'mul_ru',
'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz',
'umul24': 'mul24', 'umulhi': 'mulhi', 'mul64hi': 'mulhi',
'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf': 'nextafter',
'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf',
'normcdfinvf': 'normcdfinv', 'popcll': 'popc', 'powif': 'pow', 'powi': 'pow',
'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd', 'drcp_rd': 'rcp_rd',
'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru',
'drcp_ru': 'rcp_ru', 'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz',
'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot',
'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d',
'roundf': 'round', 'rsqrtf': 'rsqrt', 'frsqrt_rn': 'rsqrt_rn',
'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit',
'signbitd': 'signbit', 'sinf': 'sin', 'sinhf': 'sinh',
'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd',
'dsqrt_rd': 'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn',
'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru', 'fsqrt_rz': 'sqrt_rz',
'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd',
'fsub_rn': 'sub_rn', 'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru',
'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz',
'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc',
'y0f': 'y0', 'y1f': 'y1', 'ynf': 'yn'
}
# The following cases are grouped together:
# op_name, <u/ull/ll>op_name<ll/f/i>
for symbol in self._symbols.values():
op_name = symbol.op_name
if op_name in renaming:
op_name = renaming[op_name]
if "max" in op_name:
op_name = "max"
elif "min" in op_name:
op_name = "min"
elif "abs" in op_name:
op_name = "abs"
elif "pow" in op_name and "fast" in op_name:
op_name = "pow"
elif "round" in op_name:
if "llround" in op_name:
op_name = "llround"
else:
op_name = "round"
elif "rint" in op_name:
if "llrint" in op_name:
op_name = "llrint"
else:
op_name = "rint"
elif op_name.startswith("ull"):
if "2" not in op_name:
# e.g., ullmax->max
op_name = op_name[3:]
else:
# e.g., ull2double->ll2double
op_name = op_name[1:]
elif op_name.startswith("u"):
if "2" not in op_name:
# e.g., uhadd->hadd
op_name = op_name[1:]
else:
# e.g., uint2double_rn->int2double_rn
op_name = op_name[1:]
elif op_name.startswith("ll"):
if "2" not in op_name:
# e.g., llmax->max
op_name = op_name[2:]
elif op_name.endswith("ll"):
op_name = op_name[:-2]
elif op_name.endswith("f"):
op_name = op_name[:-1]
if op_name in symbol_set:
# Update op_name only if there's an existing symbol
symbol._op_name = op_name
else:
op_name = symbol._op_name
if op_name in self._symbol_groups:
self._symbol_groups[op_name].append(symbol)
else:
self._symbol_groups[op_name] = [symbol]
def parse_symbols(self, input_file) -> None:
def parse_symbols(self, input_file):
if len(self.symbols) > 0:
return
output = subprocess.check_output(["grep", "define", input_file]).decode().splitlines()
@@ -284,7 +242,7 @@ class Libdevice(ExternLibrary):
self._group_symbols()
def _output_stubs(self) -> str:
def _output_stubs(self):
# Generate python functions in the following format:
# @extern.extern
# def <op_name>(<args>, _builder=None):
@@ -292,7 +250,7 @@ class Libdevice(ExternLibrary):
# return extern.dispatch("libdevice", <path>, <args>, <arg_type_symbol_dict>, _builder)
import_str = "from . import core, extern\n"
import_str += "import os\n"
header_str = "LIBDEVICE_PATH = os.path.dirname(\n\tos.path.abspath(__file__)) + \"/libdevice.10.bc\"\n"
header_str = "LIBDEVICE_PATH = os.path.dirname(os.path.abspath(__file__)) + \"/libdevice.10.bc\"\n"
func_str = ""
for symbols in self._symbol_groups.values():
func_str += "@extern.extern\n"
@@ -325,10 +283,7 @@ class Libdevice(ExternLibrary):
class LLVMDisassembler:
_path: str
_ll_file: str
def __init__(self, path) -> None:
def __init__(self, path):
'''
Invoke llvm-dis to disassemble the given file.
@@ -337,28 +292,23 @@ class LLVMDisassembler:
self._path = path
self._ll_file = "/tmp/extern_lib.ll"
def disasm(self, lib_path: str) -> None:
def disasm(self, lib_path):
subprocess.Popen([self._path, lib_path, "-o", self.ll_file],
stdout=subprocess.PIPE).communicate()
@property
def ll_file(self) -> str:
def ll_file(self):
return self._ll_file
@property
def path(self) -> str:
def path(self):
return self._path
extern_libs = ["libdevice"]
def build(
llvm_dis_path: str,
lib_path: str,
lib_name: str,
output_dir: str,
) -> None:
def build(llvm_dis_path, lib_path, lib_name, output_dir):
'''
Interface function to build the library file.
@@ -381,10 +331,10 @@ def build(
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--llvm-dis", dest="llvm_dis_path", help="Path to llvm-dis", default="llvm-dis")
parser.add_argument("--lib-path", dest="lib_path", help="Path to the extern library")
parser.add_argument("--lib-name", dest="lib_name", help="Name of the extern library")
parser.add_argument("--output", dest="output_dir", help="Output file path", default="/tmp/")
parser.add_argument("-llvm", dest="llvm_dis_path", help="path to llvm-dis", default="llvm-dis")
parser.add_argument("--lib-path", dest="lib_path", help="path to the extern library")
parser.add_argument("--lib-name", dest="lib_name", help="name of the extern library")
parser.add_argument("-o", dest="output_dir", help="output file path", default="/tmp/")
args = parser.parse_args()
build(args.llvm_dis_path, args.lib_path, args.lib_name, args.output_dir)
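To make the renaming and grouping rules in _group_symbols above more concrete, here is a toy sketch (not part of the diff) of the same normalization idea: explicit renames first, then prefix/suffix stripping. The renaming dict here is only a small illustrative subset, and unlike the real code this sketch does not check that the normalized name already exists in the symbol set.

renaming = {'fadd_rn': 'add_rn', 'dadd_rn': 'add_rn', 'fabsf': 'abs', 'asinf': 'asin'}

def normalize(op_name: str) -> str:
    if op_name in renaming:
        return renaming[op_name]
    if op_name.startswith('ull') and '2' not in op_name:
        return op_name[3:]   # e.g. ullmax -> max
    if op_name.startswith('u'):
        return op_name[1:]   # e.g. uint2double_rn -> int2double_rn
    if op_name.startswith('ll') and '2' not in op_name:
        return op_name[2:]   # e.g. llmax -> max
    if op_name.endswith('ll'):
        return op_name[:-2]  # e.g. popcll -> popc
    if op_name.endswith('f'):
        return op_name[:-1]  # e.g. tanhf -> tanh
    return op_name

for name in ['fadd_rn', 'ullmax', 'uint2double_rn', 'tanhf']:
    print(name, '->', normalize(name))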

View File

@@ -7,7 +7,7 @@ Please refer to https://docs.nvidia.com/cuda/libdevice-users-guide/index.html re
In `triton/language/libdevice.py`, we try to aggregate functions with the same computation but different data types together.
For example, both `__nv_asin` and `__nv_asinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`.
Using triton, you can simply call `tl.libdevice.asin`.
Using triton, you can simply call `tl.libdevice.asinf`.
triton automatically selects the correct underlying device function to invoke based on input and output types.
"""

View File

@@ -0,0 +1,14 @@
FROM rocm/pytorch:rocm5.2.3_ubuntu20.04_py3.7_pytorch_1.12.1
# build triton
RUN export TRITON_USE_ROCM=ON MI_GPU_ARCH=gfx90a
# Unit Tests
# to run unit tests
# 1. build this Dockerfile
# docker build -f triton_rocm_20-52.Dockerfile -t triton_rocm52 .
# 2. run docker container
# docker run -it --rm --network=host --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --name triton --ipc=host --shm-size 16G --device=/dev/kfd --device=/dev/dri triton_rocm52:latest
# 3. run core unit tests on a rocm machine
# cd ~/triton/python
# pytest --verbose test/unit/language/test_core.py | tee test_core.log