Compare commits

59 Commits:

46fd581b0a
8cc448d92e
9a9fabbba9
15886b5ffc
d5830b4b6a
bba1579485
cc6b5180c7
dfad6bdf36
f3bcbcfde6
7ec29a7453
4fb9d4904e
4f3e2d6ed7
fecc7ce248
277b712284
d024f0cfb8
1811791665
9b3f2487b5
14730a2352
15683986cd
48fcd8c987
8d9572bc63
ffb30cdc52
7fce2bc5f1
531ef18cb6
5f0d90db7e
03ae41b310
bd61338b31
6e50f8b2c0
aa556d4f1b
61e88efb23
ed9638801a
8ecab462f6
648e4cfe89
abe0d3e1b1
4464dfcc18
0cae0168ec
88d57ef9c9
39381d99f8
df925f7187
e84297ca79
61c85c18b2
da5c24ffcb
09302f0106
9184b5cf65
8da4323514
eb89e9bdd9
56a06f7a06
6a31c43774
8785793445
d022f5cf2c
4624fd4e1d
41144f927f
4d6d4c9431
32dbc08c05
4f21501def
5c548fb57e
fa4d0fd1ef
406d03bfaf
94d5c2e8b5
.gitignore (vendored, 3 changes)

@@ -9,4 +9,5 @@ python/triton/_C/libtriton.pyd
python/triton/_C/libtriton.so
.vscode
.vs
.vs
log_*
@@ -3,6 +3,13 @@ include(ExternalProject)

set(CMAKE_CXX_STANDARD 17)

if(NOT TRITON_LLVM_BUILD_DIR)
  set(TRITON_LLVM_BUILD_DIR ${CMAKE_BINARY_DIR})
endif()

set(TRITON_USE_ROCM "$ENV{TRITON_USE_ROCM}")
set(TRITON_ROCM_DEBUG "$ENV{TRITON_ROCM_DEBUG}")

project(triton)
include(CTest)
if(NOT WIN32)

@@ -35,7 +42,11 @@ if(WIN32)
  add_subdirectory(deps/dlfcn-win32/src ${CMAKE_BINARY_DIR}/dlfcn-win32)
endif()

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17")
if (TRITON_USE_ROCM)
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17 -Wno-unused-result -Wno-attributes")
else()
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17")
endif()

##########

@@ -135,6 +146,13 @@ if(BUILD_PYTHON_MODULE)
  endif()
  include_directories("." ${PYTHON_SRC_PATH} ${PYTHON_INCLUDE_DIRS} ${CUTLASS_INCLUDE_DIR})
  link_directories(${PYTHON_LINK_DIRS} ${CUTLASS_LIBRARY_DIR})
  if (TRITON_USE_ROCM)
    add_definitions(-DUSE_ROCM)
  endif()
  if (TRITON_ROCM_DEBUG)
    add_definitions(-DDEBUG_ROCM)
  endif()

  set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/triton.cc ${PYTHON_SRC_PATH}/superblock.cc ${CUTLASS_SRC})
endif()
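For context, a minimal sketch (not from this diff) of how the USE_ROCM definition added above is typically consumed in the C++ sources; the backend strings are illustrative only. Configuration reads the flag from the environment, e.g. TRITON_USE_ROCM=1 cmake .., via $ENV{TRITON_USE_ROCM}:

    // Sketch: sources can key backend selection off the CMake define.
    #ifdef USE_ROCM            // defined by add_definitions(-DUSE_ROCM) above
    constexpr const char* kBackend = "amdgpu";
    #else
    constexpr const char* kBackend = "nvptx";
    #endif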
include/print_helper.h (new executable file, +163 lines)

@@ -0,0 +1,163 @@
#pragma once

#ifndef _PRINT_IR_H_
#define _PRINT_IR_H_

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IR/IRPrintingPasses.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IR/Module.h"
#include <fstream>
#include <iostream>
#include <sstream>
#include <chrono>  // needed by return_current_time_and_date() below
#include <vector>  // needed by print_vector() below
#include "triton/ir/module.h"
#include "triton/ir/print.h"
#include <iomanip>

#define PRINT_CURRENT_FUNCTION() std::cout << __FILE__ << ":" << __LINE__ << ":" << __FUNCTION__ << std::endl;

static int print_count = 0;

inline std::string return_current_time_and_date()
{
  auto now = std::chrono::system_clock::now();
  auto in_time_t = std::chrono::system_clock::to_time_t(now);

  std::stringstream ss;
  ss << std::put_time(std::localtime(&in_time_t), "%Y-%m-%d--%I-%M-%S");
  return ss.str();
}

template <typename T>
inline void print_vector(std::vector<T> &vec, std::string name = "")
{
  std::cout << name << ": ";
  for (auto v : vec)
  {
    std::cout << v << ", ";
  }

  std::cout << '\b';
  std::cout << std::endl;
}

// dump llvm ir to tmp file
inline std::string print_llvm_module(llvm::Module *llvm_module, bool print_to_cout = true)
{
  std::cout << "\t" << "print_llvm_module" << std::endl;
  // get module as a string
  std::error_code ec;
  std::string mod_string;
  std::unique_ptr<llvm::raw_string_ostream> ir_ss(
      new llvm::raw_string_ostream(mod_string));
  llvm_module->print(*ir_ss, nullptr);

  // print module
  if (print_to_cout)
  {
    if (!mod_string.empty())
      std::cout << "\t" << mod_string << std::endl;
    else
      std::cout << "\t" << llvm_module->getModuleIdentifier() << ": "
                << "is empty" << std::endl;
  }

  return mod_string;
}

// dump llvm ir to tmp file
inline void write_llvm_ir(llvm::Module *llvm_module, std::string filename = "", bool tracked = false)
{
  // get module string
  std::string module_string = print_llvm_module(llvm_module, false);

  // get file name and path
  if (filename.empty())
    filename = llvm_module->getModuleIdentifier();
  std::string count_str = "";
  if (tracked)
  {
    count_str = "_" + std::to_string(print_count);
  }
  std::string ir_path = std::string("/tmp/") + filename + count_str + std::string(".ll");

  // write file
  std::ofstream output_file(ir_path);
  output_file << module_string;
  output_file.close();

  // increment counter
  if (tracked)
  {
    print_count += 1;
  }
}

inline void print_triton_ir(triton::ir::module ir_ref, std::string name)
{
  std::ofstream ir_out(std::string("/tmp/") + name + std::string("_") + return_current_time_and_date() + std::string(".ttir"));
  ir_out.flush();
  triton::ir::print(ir_ref, ir_out);
  ir_out.close();
}

inline void print_triton_ir(std::string ir_ref, std::string name)
{
  std::ofstream ir_out(std::string("/tmp/") + name + std::string("_") + return_current_time_and_date() + std::string(".ttir"));
  ir_out.flush();
  ir_out << ir_ref << std::endl;
  ir_out.close();
}

inline std::string get_llvm_value_as_str(llvm::Value *llvm_value)
{
  std::string value_str;
  llvm::raw_string_ostream rso(value_str);
  llvm_value->print(rso);
  return rso.str();
}

inline void print_llvm_value(llvm::Value *llvm_value, std::string name = "")
{
  if (llvm_value)
    std::cout << "\t" << name << ": " << get_llvm_value_as_str(llvm_value) << std::endl;
  else
    std::cout << "\t" << name << ": "
              << "is nullptr" << std::endl;
}

inline void print_llvm_type(llvm::Type *llvm_type, std::string name = "")
{
  std::string type_str;
  llvm::raw_string_ostream rso(type_str);
  llvm_type->print(rso);
  std::cout << name << " type: " << rso.str() << std::endl;
}

inline void print_llvm_value_type(llvm::Value *llvm_value, std::string name = "")
{
  print_llvm_type(llvm_value->getType(), name);
}

inline void write_ptx(std::string ptx_str)
{
  std::ofstream file("/tmp/kernel.ptx");
  file << ptx_str;
}
#endif
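These are debug utilities; a minimal usage sketch, assuming the header and LLVM are on the include path (the "demo" module name is illustrative):

    #include "print_helper.h"
    #include "llvm/IR/LLVMContext.h"
    #include "llvm/IR/Module.h"

    int main() {
      llvm::LLVMContext ctx;
      llvm::Module mod("demo", ctx);                  // freshly created module
      print_llvm_module(&mod);                        // echo its IR text to stdout
      write_llvm_ir(&mod, "demo", /*tracked=*/true);  // writes /tmp/demo_0.ll, then _1, ...
      return 0;
    }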
@@ -36,6 +36,7 @@ namespace triton{
namespace codegen{

class nvidia_cu_target;
class amd_cl_target;

class target {
public:

@@ -49,7 +50,12 @@ public:
  virtual Value* get_block_id(Module *module, Builder& builder, unsigned ax) = 0;
  virtual Value* get_num_blocks(Module *module, Builder& builder, unsigned ax) = 0;
  virtual unsigned guaranteed_alignment() = 0;
#ifdef USE_ROCM
  amd_cl_target* as_nvidia();
  amd_cl_target* as_amd();
#else
  nvidia_cu_target* as_nvidia();
#endif
  bool is_gpu() const;

private:

@@ -67,6 +73,7 @@ public:
  Value* get_block_id(Module *module, Builder& builder, unsigned ax);
  Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
  unsigned guaranteed_alignment() { return 16; }
  int sm() { return 0; } // treat as if old CUDA device
};

class nvidia_cu_target: public target {
@@ -11,7 +11,6 @@
#include "triton/external/CUDA/nvml.h"

//// HIP backend
//#define __HIP_PLATFORM_AMD__
#include "triton/external/hip.h"

//Exceptions

@@ -183,7 +182,8 @@ public:
  static hipError_t hipEventElapsedTime(float *pMilliseconds, hipEvent_t hStart, hipEvent_t hEnd);
  static hipError_t hipEventRecord(hipEvent_t hEvent, hipStream_t hStream);
  static hipError_t hipEventDestroy(hipEvent_t hEvent);

  // error handling
  static hipError_t hipGetLastError(void);

private:

@@ -309,6 +309,8 @@ private:
  static void* hipEventElapsedTime_;
  static void* hipEventRecord_;
  static void* hipEventDestroy_;
  // error handling
  static void* hipGetLastError_;
};

}
@@ -13,8 +13,12 @@ std::string path_to_ptxas(int& version);
std::string llir_to_ptx(llvm::Module* module, int cc, int version);
std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas_path, int cc);
CUmodule ptx_to_cumodule(const std::string& ptx, int cc);
std::string llir_to_amdgpu(llvm::Module* module, const std::string& proc);
std::tuple<std::string, std::string> llir_to_amdgcn(llvm::Module* module, const std::string& proc);
hipModule_t amdgpu_to_hipmodule(const std::string& path);

}
}

#define STRINGIFY_HELPER(X) #X
#define STRINGIFY(X) STRINGIFY_HELPER(X)
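A sketch of how the new AMD entry points would chain together; the "gfx90a" arch string and the interpretation of the returned tuple (GCN assembly text plus a path to the compiled binary) are assumptions, not stated in this diff:

    // LLVM IR -> (GCN assembly, binary path) -> loaded HIP module
    auto [gcn_asm, hsaco_path] = llir_to_amdgcn(module, "gfx90a");
    hipModule_t hip_mod = amdgpu_to_hipmodule(hsaco_path);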
include/triton/external/hip.h (vendored, 319 changes)
@@ -1,13 +1,35 @@
/*
 * @brief hipError_t
 * @enum
 * @ingroup Enumerations
 */
// Developer note - when updating these, update the hipErrorName and hipErrorString functions in
// NVCC and HCC paths Also update the hipCUDAErrorTohipError function in NVCC path.
Copyright (c) 2015 - 2022 Advanced Micro Devices, Inc. All rights reserved.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/

#ifndef HIP_H
#define HIP_H

// Ignoring error-code return values from hip APIs is discouraged. On C++17,
// we can make that yield a warning
#if __cplusplus >= 201703L
#define __HIP_NODISCARD [[nodiscard]]
#else
#define __HIP_NODISCARD
#endif

/*
 * @brief hipError_t
@@ -17,9 +39,7 @@
// Developer note - when updating these, update the hipErrorName and hipErrorString functions in
// NVCC and HCC paths Also update the hipCUDAErrorTohipError function in NVCC path.

#include <cstddef>

typedef enum hipError_t {
typedef enum __HIP_NODISCARD hipError_t {
    hipSuccess = 0, ///< Successful completion.
    hipErrorInvalidValue = 1, ///< One or more of the parameters passed to the API call is NULL
                              ///< or not in an acceptable range.
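Under C++17 the __HIP_NODISCARD macro expands to [[nodiscard]] on the enum itself, so silently dropping a hipError_t return becomes a compiler diagnostic. A standalone illustration with demo names (compile with -std=c++17 and warnings enabled):

    enum [[nodiscard]] demo_error { demo_ok = 0, demo_fail = 1 };

    demo_error do_work() { return demo_fail; }

    int main() {
      do_work();        // warning: ignoring return value of nodiscard type
      (void)do_work();  // explicit cast to void silences the warning
    }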
@@ -73,6 +93,7 @@ typedef enum hipError_t {
    hipErrorInvalidHandle = 400,
    // Deprecated
    hipErrorInvalidResourceHandle = 400, ///< Resource handle (hipEvent_t or hipStream_t) invalid.
    hipErrorIllegalState = 401, ///< Resource required is not in a valid state to perform operation.
    hipErrorNotFound = 500,
    hipErrorNotReady = 600, ///< Indicates that asynchronous operations enqueued earlier are not
                            ///< ready. This is not actually an error, but is used to distinguish

@@ -86,6 +107,7 @@ typedef enum hipError_t {
    hipErrorPeerAccessNotEnabled =
        705, ///< Peer access was never enabled from the current device.
    hipErrorSetOnActiveProcess = 708,
    hipErrorContextIsDestroyed = 709,
    hipErrorAssert = 710, ///< Produced when the kernel calls assert.
    hipErrorHostMemoryAlreadyRegistered =
        712, ///< Produced when trying to lock a page-locked memory.

@@ -98,6 +120,32 @@ typedef enum hipError_t {
        ///< that was launched via cooperative launch APIs exceeds the maximum number of
        ///< allowed blocks for the current device
    hipErrorNotSupported = 801, ///< Produced when the hip API is not supported/implemented
    hipErrorStreamCaptureUnsupported = 900, ///< The operation is not permitted when the stream
                                            ///< is capturing.
    hipErrorStreamCaptureInvalidated = 901, ///< The current capture sequence on the stream
                                            ///< has been invalidated due to a previous error.
    hipErrorStreamCaptureMerge = 902, ///< The operation would have resulted in a merge of
                                      ///< two independent capture sequences.
    hipErrorStreamCaptureUnmatched = 903, ///< The capture was not initiated in this stream.
    hipErrorStreamCaptureUnjoined = 904, ///< The capture sequence contains a fork that was not
                                         ///< joined to the primary stream.
    hipErrorStreamCaptureIsolation = 905, ///< A dependency would have been created which crosses
                                          ///< the capture sequence boundary. Only implicit
                                          ///< in-stream ordering dependencies are allowed
                                          ///< to cross the boundary
    hipErrorStreamCaptureImplicit = 906, ///< The operation would have resulted in a disallowed
                                         ///< implicit dependency on a current capture sequence
                                         ///< from hipStreamLegacy.
    hipErrorCapturedEvent = 907, ///< The operation is not permitted on an event which was last
                                 ///< recorded in a capturing stream.
    hipErrorStreamCaptureWrongThread = 908, ///< A stream capture sequence not initiated with
                                            ///< the hipStreamCaptureModeRelaxed argument to
                                            ///< hipStreamBeginCapture was passed to
                                            ///< hipStreamEndCapture in a different thread.
    hipErrorGraphExecUpdateFailure = 910, ///< This error indicates that the graph update
                                          ///< not performed because it included changes which
                                          ///< violated constraints specific to instantiated graph
                                          ///< update.
    hipErrorUnknown = 999, //< Unknown error.
    // HSA Runtime Error Codes start here.
    hipErrorRuntimeMemory = 1052, ///< HSA runtime memory call returned error. Typically not seen
@@ -107,35 +155,154 @@ typedef enum hipError_t {
    hipErrorTbd ///< Marker that more error codes are needed.
} hipError_t;

#undef __HIP_NODISCARD

/*
 * @brief hipDeviceAttribute_t
 * @enum
 * @ingroup Enumerations
 */
typedef enum hipDeviceAttribute_t {
    hipDeviceAttributeCudaCompatibleBegin = 0,

    hipDeviceAttributeEccEnabled = hipDeviceAttributeCudaCompatibleBegin, ///< Whether ECC support is enabled.
    hipDeviceAttributeAccessPolicyMaxWindowSize, ///< Cuda only. The maximum size of the window policy in bytes.
    hipDeviceAttributeAsyncEngineCount, ///< Cuda only. Asynchronous engines number.
    hipDeviceAttributeCanMapHostMemory, ///< Whether host memory can be mapped into device address space
    hipDeviceAttributeCanUseHostPointerForRegisteredMem, ///< Cuda only. Device can access host registered memory
                                                         ///< at the same virtual address as the CPU
    hipDeviceAttributeClockRate, ///< Peak clock frequency in kilohertz.
    hipDeviceAttributeComputeMode, ///< Compute mode that device is currently in.
    hipDeviceAttributeComputePreemptionSupported, ///< Cuda only. Device supports Compute Preemption.
    hipDeviceAttributeConcurrentKernels, ///< Device can possibly execute multiple kernels concurrently.
    hipDeviceAttributeConcurrentManagedAccess, ///< Device can coherently access managed memory concurrently with the CPU
    hipDeviceAttributeCooperativeLaunch, ///< Support cooperative launch
    hipDeviceAttributeCooperativeMultiDeviceLaunch, ///< Support cooperative launch on multiple devices
    hipDeviceAttributeDeviceOverlap, ///< Cuda only. Device can concurrently copy memory and execute a kernel.
                                     ///< Deprecated. Use instead asyncEngineCount.
    hipDeviceAttributeDirectManagedMemAccessFromHost, ///< Host can directly access managed memory on
                                                      ///< the device without migration
    hipDeviceAttributeGlobalL1CacheSupported, ///< Cuda only. Device supports caching globals in L1
    hipDeviceAttributeHostNativeAtomicSupported, ///< Cuda only. Link between the device and the host supports native atomic operations
    hipDeviceAttributeIntegrated, ///< Device is integrated GPU
    hipDeviceAttributeIsMultiGpuBoard, ///< Multiple GPU devices.
    hipDeviceAttributeKernelExecTimeout, ///< Run time limit for kernels executed on the device
    hipDeviceAttributeL2CacheSize, ///< Size of L2 cache in bytes. 0 if the device doesn't have L2 cache.
    hipDeviceAttributeLocalL1CacheSupported, ///< caching locals in L1 is supported
    hipDeviceAttributeLuid, ///< Cuda only. 8-byte locally unique identifier in 8 bytes. Undefined on TCC and non-Windows platforms
    hipDeviceAttributeLuidDeviceNodeMask, ///< Cuda only. Luid device node mask. Undefined on TCC and non-Windows platforms
    hipDeviceAttributeComputeCapabilityMajor, ///< Major compute capability version number.
    hipDeviceAttributeManagedMemory, ///< Device supports allocating managed memory on this system
    hipDeviceAttributeMaxBlocksPerMultiProcessor, ///< Cuda only. Max block size per multiprocessor
    hipDeviceAttributeMaxBlockDimX, ///< Max block size in width.
    hipDeviceAttributeMaxBlockDimY, ///< Max block size in height.
    hipDeviceAttributeMaxBlockDimZ, ///< Max block size in depth.
    hipDeviceAttributeMaxGridDimX, ///< Max grid size in width.
    hipDeviceAttributeMaxGridDimY, ///< Max grid size in height.
    hipDeviceAttributeMaxGridDimZ, ///< Max grid size in depth.
    hipDeviceAttributeMaxSurface1D, ///< Maximum size of 1D surface.
    hipDeviceAttributeMaxSurface1DLayered, ///< Cuda only. Maximum dimensions of 1D layered surface.
    hipDeviceAttributeMaxSurface2D, ///< Maximum dimension (width, height) of 2D surface.
    hipDeviceAttributeMaxSurface2DLayered, ///< Cuda only. Maximum dimensions of 2D layered surface.
    hipDeviceAttributeMaxSurface3D, ///< Maximum dimension (width, height, depth) of 3D surface.
    hipDeviceAttributeMaxSurfaceCubemap, ///< Cuda only. Maximum dimensions of Cubemap surface.
    hipDeviceAttributeMaxSurfaceCubemapLayered, ///< Cuda only. Maximum dimension of Cubemap layered surface.
    hipDeviceAttributeMaxTexture1DWidth, ///< Maximum size of 1D texture.
    hipDeviceAttributeMaxTexture1DLayered, ///< Cuda only. Maximum dimensions of 1D layered texture.
    hipDeviceAttributeMaxTexture1DLinear, ///< Maximum number of elements allocatable in a 1D linear texture.
                                          ///< Use cudaDeviceGetTexture1DLinearMaxWidth() instead on Cuda.
    hipDeviceAttributeMaxTexture1DMipmap, ///< Cuda only. Maximum size of 1D mipmapped texture.
    hipDeviceAttributeMaxTexture2DWidth, ///< Maximum dimension width of 2D texture.
    hipDeviceAttributeMaxTexture2DHeight, ///< Maximum dimension height of 2D texture.
    hipDeviceAttributeMaxTexture2DGather, ///< Cuda only. Maximum dimensions of 2D texture if gather operations performed.
    hipDeviceAttributeMaxTexture2DLayered, ///< Cuda only. Maximum dimensions of 2D layered texture.
    hipDeviceAttributeMaxTexture2DLinear, ///< Cuda only. Maximum dimensions (width, height, pitch) of 2D textures bound to pitched memory.
    hipDeviceAttributeMaxTexture2DMipmap, ///< Cuda only. Maximum dimensions of 2D mipmapped texture.
    hipDeviceAttributeMaxTexture3DWidth, ///< Maximum dimension width of 3D texture.
    hipDeviceAttributeMaxTexture3DHeight, ///< Maximum dimension height of 3D texture.
    hipDeviceAttributeMaxTexture3DDepth, ///< Maximum dimension depth of 3D texture.
    hipDeviceAttributeMaxTexture3DAlt, ///< Cuda only. Maximum dimensions of alternate 3D texture.
    hipDeviceAttributeMaxTextureCubemap, ///< Cuda only. Maximum dimensions of Cubemap texture
    hipDeviceAttributeMaxTextureCubemapLayered, ///< Cuda only. Maximum dimensions of Cubemap layered texture.
    hipDeviceAttributeMaxThreadsDim, ///< Maximum dimension of a block
    hipDeviceAttributeMaxThreadsPerBlock, ///< Maximum number of threads per block.
    hipDeviceAttributeMaxThreadsPerMultiProcessor, ///< Maximum resident threads per multiprocessor.
    hipDeviceAttributeMaxPitch, ///< Maximum pitch in bytes allowed by memory copies
    hipDeviceAttributeMemoryBusWidth, ///< Global memory bus width in bits.
    hipDeviceAttributeMemoryClockRate, ///< Peak memory clock frequency in kilohertz.
    hipDeviceAttributeComputeCapabilityMinor, ///< Minor compute capability version number.
    hipDeviceAttributeMultiGpuBoardGroupID, ///< Cuda only. Unique ID of device group on the same multi-GPU board
    hipDeviceAttributeMultiprocessorCount, ///< Number of multiprocessors on the device.
    hipDeviceAttributeName, ///< Device name.
    hipDeviceAttributePageableMemoryAccess, ///< Device supports coherently accessing pageable memory
                                            ///< without calling hipHostRegister on it
    hipDeviceAttributePageableMemoryAccessUsesHostPageTables, ///< Device accesses pageable memory via the host's page tables
    hipDeviceAttributePciBusId, ///< PCI Bus ID.
    hipDeviceAttributePciDeviceId, ///< PCI Device ID.
    hipDeviceAttributePciDomainID, ///< PCI Domain ID.
    hipDeviceAttributePersistingL2CacheMaxSize, ///< Cuda11 only. Maximum l2 persisting lines capacity in bytes
    hipDeviceAttributeMaxRegistersPerBlock, ///< 32-bit registers available to a thread block. This number is shared
                                            ///< by all thread blocks simultaneously resident on a multiprocessor.
    hipDeviceAttributeMaxRegistersPerMultiprocessor, ///< 32-bit registers available per block.
    hipDeviceAttributeReservedSharedMemPerBlock, ///< Cuda11 only. Shared memory reserved by CUDA driver per block.
    hipDeviceAttributeMaxSharedMemoryPerBlock, ///< Maximum shared memory available per block in bytes.
    hipDeviceAttributeSharedMemPerBlockOptin, ///< Cuda only. Maximum shared memory per block usable by special opt in.
    hipDeviceAttributeSharedMemPerMultiprocessor, ///< Cuda only. Shared memory available per multiprocessor.
    hipDeviceAttributeSingleToDoublePrecisionPerfRatio, ///< Cuda only. Performance ratio of single precision to double precision.
    hipDeviceAttributeStreamPrioritiesSupported, ///< Cuda only. Whether to support stream priorities.
    hipDeviceAttributeSurfaceAlignment, ///< Cuda only. Alignment requirement for surfaces
    hipDeviceAttributeTccDriver, ///< Cuda only. Whether device is a Tesla device using TCC driver
    hipDeviceAttributeTextureAlignment, ///< Alignment requirement for textures
    hipDeviceAttributeTexturePitchAlignment, ///< Pitch alignment requirement for 2D texture references bound to pitched memory;
    hipDeviceAttributeTotalConstantMemory, ///< Constant memory size in bytes.
    hipDeviceAttributeTotalGlobalMem, ///< Global memory available on device.
    hipDeviceAttributeUnifiedAddressing, ///< Cuda only. A unified address space shared with the host.
    hipDeviceAttributeUuid, ///< Cuda only. Unique ID in 16 byte.
    hipDeviceAttributeWarpSize, ///< Warp size in threads.
    hipDeviceAttributeMemoryPoolsSupported, ///< Device supports HIP Stream Ordered Memory Allocator

    hipDeviceAttributeCudaCompatibleEnd = 9999,
    hipDeviceAttributeAmdSpecificBegin = 10000,

    hipDeviceAttributeClockInstructionRate = hipDeviceAttributeAmdSpecificBegin, ///< Frequency in khz of the timer used by the device-side "clock*"
    hipDeviceAttributeArch, ///< Device architecture
    hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, ///< Maximum Shared Memory PerMultiprocessor.
    hipDeviceAttributeGcnArch, ///< Device gcn architecture
    hipDeviceAttributeGcnArchName, ///< Device gcnArch name in 256 bytes
    hipDeviceAttributeHdpMemFlushCntl, ///< Address of the HDP_MEM_COHERENCY_FLUSH_CNTL register
    hipDeviceAttributeHdpRegFlushCntl, ///< Address of the HDP_REG_COHERENCY_FLUSH_CNTL register
    hipDeviceAttributeCooperativeMultiDeviceUnmatchedFunc, ///< Supports cooperative launch on multiple
                                                           ///< devices with unmatched functions
    hipDeviceAttributeCooperativeMultiDeviceUnmatchedGridDim, ///< Supports cooperative launch on multiple
                                                              ///< devices with unmatched grid dimensions
    hipDeviceAttributeCooperativeMultiDeviceUnmatchedBlockDim, ///< Supports cooperative launch on multiple
                                                               ///< devices with unmatched block dimensions
    hipDeviceAttributeCooperativeMultiDeviceUnmatchedSharedMem, ///< Supports cooperative launch on multiple
                                                                ///< devices with unmatched shared memories
    hipDeviceAttributeIsLargeBar, ///< Whether it is LargeBar
    hipDeviceAttributeAsicRevision, ///< Revision of the GPU in this device
    hipDeviceAttributeCanUseStreamWaitValue, ///< '1' if Device supports hipStreamWaitValue32() and
                                             ///< hipStreamWaitValue64(), '0' otherwise.
    hipDeviceAttributeImageSupport, ///< '1' if Device supports image, '0' otherwise.
    hipDeviceAttributePhysicalMultiProcessorCount, ///< All available physical compute
                                                   ///< units for the device
    hipDeviceAttributeFineGrainSupport, ///< '1' if Device supports fine grain, '0' otherwise

    hipDeviceAttributeAmdSpecificEnd = 19999,
    hipDeviceAttributeVendorSpecificBegin = 20000,
    // Extended attributes for vendors
} hipDeviceAttribute_t;

#define HIP_LAUNCH_PARAM_BUFFER_POINTER ((void*)0x01)
#define HIP_LAUNCH_PARAM_BUFFER_SIZE ((void*)0x02)
#define HIP_LAUNCH_PARAM_END ((void*)0x03)
// API-visible structures
typedef struct ihipCtx_t* hipCtx_t;

// Note many APIs also use integer deviceIds as an alternative to the device pointer:
typedef int hipDevice_t;

typedef enum hipDeviceP2PAttr {
    hipDevP2PAttrPerformanceRank = 0,
    hipDevP2PAttrAccessSupported,
    hipDevP2PAttrNativeAtomicSupported,
    hipDevP2PAttrHipArrayAccessSupported
} hipDeviceP2PAttr;

typedef struct ihipStream_t* hipStream_t;

#define hipIpcMemLazyEnablePeerAccess 0

#define HIP_IPC_HANDLE_SIZE 64

typedef struct hipIpcMemHandle_st {
    char reserved[HIP_IPC_HANDLE_SIZE];
} hipIpcMemHandle_t;

typedef struct hipIpcEventHandle_st {
    char reserved[HIP_IPC_HANDLE_SIZE];
} hipIpcEventHandle_t;

typedef struct ihipModule_t* hipModule_t;

typedef struct ihipModuleSymbol_t* hipFunction_t;

typedef struct hipFuncAttributes {
@@ -150,91 +317,8 @@ typedef struct hipFuncAttributes {
    int ptxVersion;
    size_t sharedSizeBytes;
} hipFuncAttributes;

typedef struct ihipEvent_t* hipEvent_t;

/*
 * @brief hipDeviceAttribute_t
 * @enum
 * @ingroup Enumerations
 */
typedef enum hipDeviceAttribute_t {
    hipDeviceAttributeMaxThreadsPerBlock, ///< Maximum number of threads per block.
    hipDeviceAttributeMaxBlockDimX, ///< Maximum x-dimension of a block.
    hipDeviceAttributeMaxBlockDimY, ///< Maximum y-dimension of a block.
    hipDeviceAttributeMaxBlockDimZ, ///< Maximum z-dimension of a block.
    hipDeviceAttributeMaxGridDimX, ///< Maximum x-dimension of a grid.
    hipDeviceAttributeMaxGridDimY, ///< Maximum y-dimension of a grid.
    hipDeviceAttributeMaxGridDimZ, ///< Maximum z-dimension of a grid.
    hipDeviceAttributeMaxSharedMemoryPerBlock, ///< Maximum shared memory available per block in
                                               ///< bytes.
    hipDeviceAttributeTotalConstantMemory, ///< Constant memory size in bytes.
    hipDeviceAttributeWarpSize, ///< Warp size in threads.
    hipDeviceAttributeMaxRegistersPerBlock, ///< Maximum number of 32-bit registers available to a
                                            ///< thread block. This number is shared by all thread
                                            ///< blocks simultaneously resident on a
                                            ///< multiprocessor.
    hipDeviceAttributeClockRate, ///< Peak clock frequency in kilohertz.
    hipDeviceAttributeMemoryClockRate, ///< Peak memory clock frequency in kilohertz.
    hipDeviceAttributeMemoryBusWidth, ///< Global memory bus width in bits.
    hipDeviceAttributeMultiprocessorCount, ///< Number of multiprocessors on the device.
    hipDeviceAttributeComputeMode, ///< Compute mode that device is currently in.
    hipDeviceAttributeL2CacheSize, ///< Size of L2 cache in bytes. 0 if the device doesn't have L2
                                   ///< cache.
    hipDeviceAttributeMaxThreadsPerMultiProcessor, ///< Maximum resident threads per
                                                   ///< multiprocessor.
    hipDeviceAttributeComputeCapabilityMajor, ///< Major compute capability version number.
    hipDeviceAttributeComputeCapabilityMinor, ///< Minor compute capability version number.
    hipDeviceAttributeConcurrentKernels, ///< Device can possibly execute multiple kernels
                                         ///< concurrently.
    hipDeviceAttributePciBusId, ///< PCI Bus ID.
    hipDeviceAttributePciDeviceId, ///< PCI Device ID.
    hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, ///< Maximum Shared Memory Per
                                                        ///< Multiprocessor.
    hipDeviceAttributeIsMultiGpuBoard, ///< Multiple GPU devices.
    hipDeviceAttributeIntegrated, ///< iGPU
    hipDeviceAttributeCooperativeLaunch, ///< Support cooperative launch
    hipDeviceAttributeCooperativeMultiDeviceLaunch, ///< Support cooperative launch on multiple devices
    hipDeviceAttributeMaxTexture1DWidth, ///< Maximum number of elements in 1D images
    hipDeviceAttributeMaxTexture2DWidth, ///< Maximum dimension width of 2D images in image elements
    hipDeviceAttributeMaxTexture2DHeight, ///< Maximum dimension height of 2D images in image elements
    hipDeviceAttributeMaxTexture3DWidth, ///< Maximum dimension width of 3D images in image elements
    hipDeviceAttributeMaxTexture3DHeight, ///< Maximum dimensions height of 3D images in image elements
    hipDeviceAttributeMaxTexture3DDepth, ///< Maximum dimensions depth of 3D images in image elements

    hipDeviceAttributeHdpMemFlushCntl, ///< Address of the HDP_MEM_COHERENCY_FLUSH_CNTL register
    hipDeviceAttributeHdpRegFlushCntl, ///< Address of the HDP_REG_COHERENCY_FLUSH_CNTL register

    hipDeviceAttributeMaxPitch, ///< Maximum pitch in bytes allowed by memory copies
    hipDeviceAttributeTextureAlignment, ///< Alignment requirement for textures
    hipDeviceAttributeTexturePitchAlignment, ///< Pitch alignment requirement for 2D texture references bound to pitched memory;
    hipDeviceAttributeKernelExecTimeout, ///< Run time limit for kernels executed on the device
    hipDeviceAttributeCanMapHostMemory, ///< Device can map host memory into device address space
    hipDeviceAttributeEccEnabled, ///< Device has ECC support enabled

    hipDeviceAttributeCooperativeMultiDeviceUnmatchedFunc, ///< Supports cooperative launch on multiple
                                                           ///< devices with unmatched functions
    hipDeviceAttributeCooperativeMultiDeviceUnmatchedGridDim, ///< Supports cooperative launch on multiple
                                                              ///< devices with unmatched grid dimensions
    hipDeviceAttributeCooperativeMultiDeviceUnmatchedBlockDim, ///< Supports cooperative launch on multiple
                                                               ///< devices with unmatched block dimensions
    hipDeviceAttributeCooperativeMultiDeviceUnmatchedSharedMem, ///< Supports cooperative launch on multiple
                                                                ///< devices with unmatched shared memories
    hipDeviceAttributeAsicRevision, ///< Revision of the GPU in this device
    hipDeviceAttributeManagedMemory, ///< Device supports allocating managed memory on this system
    hipDeviceAttributeDirectManagedMemAccessFromHost, ///< Host can directly access managed memory on
                                                      ///< the device without migration
    hipDeviceAttributeConcurrentManagedAccess, ///< Device can coherently access managed memory
                                               ///< concurrently with the CPU
    hipDeviceAttributePageableMemoryAccess, ///< Device supports coherently accessing pageable memory
                                            ///< without calling hipHostRegister on it
    hipDeviceAttributePageableMemoryAccessUsesHostPageTables, ///< Device accesses pageable memory via
                                                              ///< the host's page tables
    hipDeviceAttributeCanUseStreamWaitValue ///< '1' if Device supports hipStreamWaitValue32() and
                                            ///< hipStreamWaitValue64() , '0' otherwise.

} hipDeviceAttribute_t;

typedef void* hipDeviceptr_t;

/*

@@ -262,7 +346,6 @@ typedef enum hipJitOption {
    hipJitOptionFastCompile,
    hipJitOptionNumOptions
} hipJitOption;

/**
 * @warning On AMD devices and some Nvidia devices, these hints and controls are ignored.
 */

@@ -271,7 +354,6 @@ typedef enum hipFuncAttribute {
    hipFuncAttributePreferredSharedMemoryCarveout = 9,
    hipFuncAttributeMax
} hipFuncAttribute;

/**
 * @warning On AMD devices and some Nvidia devices, these hints and controls are ignored.
 */

@@ -282,7 +364,4 @@ typedef enum hipFuncCache_t {
    hipFuncCachePreferEqual, ///< prefer equal size L1 cache and shared memory
} hipFuncCache_t;

#define HIP_LAUNCH_PARAM_BUFFER_POINTER ((void*)0x01)
#define HIP_LAUNCH_PARAM_BUFFER_SIZE ((void*)0x02)
#define HIP_LAUNCH_PARAM_END ((void*)0x03)
#endif
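A quick usage sketch for the rewritten attribute enum; hipDeviceGetAttribute is the standard HIP query entry point (it lives in the HIP runtime, not in this vendored header):

    #include <hip/hip_runtime.h>
    #include <cstdio>

    int main() {
      int warp_size = 0;
      // Query an attribute of device 0 using the enum defined above.
      if (hipDeviceGetAttribute(&warp_size, hipDeviceAttributeWarpSize, 0) == hipSuccess)
        std::printf("warp size: %d\n", warp_size);
      return 0;
    }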
@@ -45,6 +45,7 @@ public:
  value *get_int64(uint64_t val);
  value *get_float16(float val);
  value *get_float32(float val);
  value *get_float64(float val);
  value *get_range(int32_t lo, int32_t hi);
  // Types
  type *get_void_ty();
@@ -196,7 +196,7 @@ mma_layout::mma_layout(size_t num_warps,
  tensor_core_type_ = get_mma_type(dot);
  /* fragments per warp */
  // try to make things as square as possible to maximize data re-use
  if(tgt->as_nvidia()->sm() < 80){
  if(tgt->as_nvidia() && tgt->as_nvidia()->sm() < 80){
    fpw_ = {2, 2, 1};
    auto ord_a = layout_a->get_order();
    auto ord_b = layout_b->get_order();
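Why the added null check matters: as_nvidia() is a dynamic_cast, which yields nullptr when the target is not actually an nvidia_cu_target, so calling ->sm() unguarded dereferences a null pointer on AMD targets. A standalone sketch of the pattern (demo types, not Triton's):

    struct target { virtual ~target() = default; };
    struct nvidia_cu_target : target { int sm() { return 80; } };
    struct amd_cl_target : target {};

    bool is_pre_ampere(target* tgt) {
      auto* nv = dynamic_cast<nvidia_cu_target*>(tgt);  // nullptr on AMD
      return nv && nv->sm() < 80;   // short-circuit prevents the null deref
    }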
@@ -79,11 +79,13 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(
    ir::module& ir, llvm::LLVMContext& ctx, codegen::target* target,
    int num_warps, int num_stages, int& shared_static,
    const ExternLibMap& extern_lib_map) {
  std::cout << "pass.cc: add_passes_to_emit_bin" << std::endl;
  // generate llvm code
  std::string name = ir.get_function_list()[0]->get_name();
  std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
  // optimizations
  bool has_sm80 = target->as_nvidia() && target->as_nvidia()->sm() >= 80;
  // bool has_sm80 = target->as_nvidia() && target->as_nvidia()->sm() >= 80;
  bool has_sm80 = false;
  // create passes
  codegen::analysis::align align;
  codegen::transform::inliner inliner;
@@ -16,7 +16,13 @@
#include "triton/ir/utils.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IR/Type.h"
#ifdef USE_ROCM
#include "llvm/IR/IntrinsicsAMDGPU.h"
#else
#include "llvm/IR/IntrinsicsNVPTX.h"
#endif
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/InlineAsm.h"

@@ -91,6 +97,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
#define f16_ty builder_->getHalfTy()
#define bf16_ty builder_->getInt16Ty()
#define f32_ty builder_->getFloatTy()
#define f64_ty builder_->getDoubleTy()
#define i1_ty builder_->getInt1Ty()
#define i8_ty builder_->getInt8Ty()
#define i16_ty builder_->getInt16Ty()

@@ -410,48 +417,44 @@ void generator::visit_binary_operator(ir::binary_operator*x) {
  for(indices_t idx: idxs_.at(x)){
    Value *lhs = vals_[x->get_operand(0)][idx];
    Value *rhs = vals_[x->get_operand(1)][idx];
    // manually select bf16 bin op
    if (x->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty()) {
      assert(x->get_operand(1)->get_type()->get_scalar_ty()->is_bf16_ty());
      if (x->get_op() == tt::FAdd) { // a + b = a * 1.0 + b
      if (x->get_op() == tt::FAdd) {
        InlineAsm *bf16_add_asm =
            InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
                           "{ .reg .b16 c; \n\t"
                           " mov.b16 c, 0x3f80U; \n\t" // 1.0
                           " fma.rn.bf16 $0, $1, c, $2; } \n\t",
                           "=h,h,h", false);
            InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
                           "v_lshlrev_b32 $1, 16, $1 "
                           "v_lshlrev_b32 $2, 16, $2 "
                           "v_add_f32 $0, $1, $2 "
                           "v_lshrrev_b32 $0, 16, $0 ",
                           "=v,v,v", false);
        vals_[x][idx] = builder_->CreateCall(bf16_add_asm, {lhs, rhs});
      } else if (x->get_op() == tt::FSub) { // a - b = b * (-1.0) + a
        InlineAsm *bf16_sub_asm =
            InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
                           " { .reg .b16 c; \n\t"
                           " mov.b16 c, 0xbf80U; \n\t" // -1.0
                           " fma.rn.bf16 $0, $2, c, $1;} \n\t",
                           "=h,h,h", false);
            InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
                           "v_lshlrev_b32 $1, 16, $1 "
                           "v_lshlrev_b32 $2, 16, $2 "
                           "v_sub_f32 $0, $1, $2 "
                           "v_lshrrev_b32 $0, 16, $0 ",
                           "=v,v,v", false);
        vals_[x][idx] = builder_->CreateCall(bf16_sub_asm, {lhs, rhs});
      } else if (x->get_op() == tt::FMul) { // a * b = a*b + 0
        InlineAsm *bf16_mul_asm =
            InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
                           " { .reg .b16 c; \n\t"
                           " mov.b16 c, 0x8000U; \n\t" // 0.0
                           " fma.rn.bf16 $0, $1, $2, c;} \n\t",
                           "=h,h,h", false);
                           "v_lshlrev_b32 $1, 16, $1 "
                           "v_lshlrev_b32 $2, 16, $2 "
                           "v_mul_f32 $0, $1, $2 "
                           "v_lshrrev_b32 $0, 16, $0 ",
                           "=v,v,v", false);
        vals_[x][idx] = builder_->CreateCall(bf16_mul_asm, {lhs, rhs});
      } else
        throw std::runtime_error("invalid bin op for bf16");
    } else { // not bf16
    }
    else {
      auto op = cvt(x->get_op());
      if(op == ll::Add)
        vals_[x][idx] = add(lhs, rhs);
      else if(op == ll::Mul)
        vals_[x][idx] = mul(lhs, rhs);
      else if(op == ll::FDiv && !x->get_fdiv_ieee_rounding() &&
              x->get_type()->get_scalar_ty()->is_fp32_ty()){
        InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {f32_ty, f32_ty}, false),
                                        " div.full.f32 $0, $1, $2;", "=r,r,r", false);
        vals_[x][idx] = builder_->CreateCall(ptx, {lhs, rhs});
      }
      else
        vals_[x][idx] = bin_op(op, lhs, rhs);
    }
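The replacement AMD inline asm emulates bf16 arithmetic in f32: shift each bf16 value into the high 16 bits of a 32-bit register (v_lshlrev_b32 16), operate in f32, then shift back down (v_lshrrev_b32 16). A scalar C++ model of the add case, assuming the shift-back rounds by truncation as the raw shift does:

    #include <cstdint>
    #include <cstring>

    // bf16 a + b via f32: bf16 bits are exactly the top half of an f32.
    uint16_t bf16_add(uint16_t a, uint16_t b) {
      uint32_t ua = uint32_t(a) << 16, ub = uint32_t(b) << 16;  // widen
      float fa, fb;
      std::memcpy(&fa, &ua, 4);
      std::memcpy(&fb, &ub, 4);
      float fc = fa + fb;                                       // f32 math
      uint32_t uc;
      std::memcpy(&uc, &fc, 4);
      return uint16_t(uc >> 16);                                // truncate back
    }

The sub and mul variants differ only in the middle f32 operation, mirroring v_sub_f32 and v_mul_f32 in the asm above.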
@@ -754,7 +757,7 @@ Value* generator::bf16_to_fp32(Value *in0){
}

Value* generator::fp32_to_bf16(Value *in0){
  if(tgt_->as_nvidia()->sm() >= 80){
  if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80){
    InlineAsm *ptx = InlineAsm::get(FunctionType::get(bf16_ty, {f32_ty}, false),
                                    "cvt.rn.bf16.f32 $0, $1;", "=h,r", false);
    return call(ptx, {in0});
@@ -1120,6 +1123,22 @@ void generator::visit_load_inst(ir::load_inst* x){
  ir::value *op = x->get_pointer_operand();
  ir::masked_load_inst *mx = dynamic_cast<ir::masked_load_inst*>(x);
  Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty());

#ifdef USE_ROCM
  // code generation
  auto idxs = idxs_.at(x);
  for(size_t i = 0; i < idxs.size(); i += 1){
    indices_t idx = idxs[i];
    // pointer value
    Value *ptr = vals_[op][idx];

    // create load
    Value *_ret = builder_->CreateLoad(ty, ptr);

    // upload to global vals map
    vals_[x][idx] = _ret;
  }
#else
  // compute vector width
  size_t vec = 1;
  bool is_mma_first_row = false;

@@ -1298,6 +1317,7 @@ void generator::visit_load_inst(ir::load_inst* x){
    for(size_t ii = 0; ii < vec; ii++)
      vals_[x][idxs[i+ii]] = extract_elt(rets[ii/tmp], ii % tmp);
  }
#endif
}

void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
@@ -1316,6 +1336,23 @@ void generator::visit_store_inst(ir::store_inst * x){
  // operands
  ir::value *ptr_op = x->get_pointer_operand();
  ir::value *val_op = x->get_value_operand();
#ifdef USE_ROCM
  auto idxs = idxs_.at(val_op);
  Type *ty = cvt(val_op->get_type()->get_scalar_ty());

  for (size_t i = 0; i < idxs.size(); i += 1)
  {
    auto idx = idxs[i];
    // pointer
    Value *ptr = vals_[ptr_op][idx];

    // value
    Value *val = vals_.at(val_op)[idxs[i]];

    // store value at pointer
    store(val, ptr);
  }
#else
  ir::value *msk_op = nullptr;
  if(auto* msk_st = dynamic_cast<ir::masked_store_inst*>(x))
    msk_op = msk_st->get_mask_operand();

@@ -1431,6 +1468,7 @@ void generator::visit_store_inst(ir::store_inst * x){
    args.push_back(policies_.at(x->get_eviction_policy()));
    call(_asm, args);
  }
#endif
}
void generator::visit_unmasked_store_inst(ir::unmasked_store_inst* x) {
  visit_store_inst(x);
@@ -1549,7 +1587,12 @@ void generator::visit_exp_inst(ir::exp_inst* x){
  Constant *log2e = ConstantFP::get(f32_ty, 1.4426950408889634);
  std::vector<llvm::Type*> tys = {f32_ty};
  FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
#ifdef USE_ROCM
  llvm::Function *ex2 = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::exp2, tys);
#else
  InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $0;", "=f,0", false);
#endif

  for(auto idx: idxs_.at(x)){
    Value *ex2arg = fmul(vals_[x->get_operand(0)][idx], log2e);
    // Value *ex2arg = vals_[x->get_operand(0)][idx];
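Both paths compute exp(x) by the identity exp(x) = exp2(x * log2(e)); ROCm uses the llvm.exp2 intrinsic where the CUDA path uses the ex2.approx PTX instruction. A scalar model of the rewrite:

    #include <cmath>

    // exp(x) == 2^(x * log2(e)); this is the constant 1.4426950408889634 above.
    float fast_exp(float x) {
      const float log2e = 1.4426950408889634f;
      return std::exp2(x * log2e);
    }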
@@ -1563,7 +1606,11 @@
void generator::visit_cos_inst(ir::cos_inst* x){
  std::vector<llvm::Type*> tys = {f32_ty};
  FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
#ifdef USE_ROCM
  llvm::Function *cos = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::cos, tys);
#else
  InlineAsm *cos = InlineAsm::get(fn_ty, "cos.approx.f32 $0, $0;", "=f,0", false);
#endif
  for(auto idx: idxs_.at(x)){
    vals_[x][idx] = call(cos, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
  }

@@ -1589,7 +1636,11 @@ void generator::visit_umulhi_inst(ir::umulhi_inst* x){
void generator::visit_sin_inst(ir::sin_inst* x){
  std::vector<llvm::Type*> tys = {f32_ty};
  FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
#ifdef USE_ROCM
  llvm::Function *sin = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::sin, tys);
#else
  InlineAsm *sin = InlineAsm::get(fn_ty, "sin.approx.f32 $0, $0;", "=f,0", false);
#endif
  for(auto idx: idxs_.at(x)){
    vals_[x][idx] = call(sin, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
  }

@@ -1602,7 +1653,11 @@ void generator::visit_log_inst(ir::log_inst* x){
  Constant *rcplog2e = ConstantFP::get(f32_ty, 0.6931471805599453);
  std::vector<llvm::Type*> tys = {f32_ty};
  FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
#ifdef USE_ROCM
  llvm::Function *lg2 = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::log2, tys);
#else
  InlineAsm *lg2 = InlineAsm::get(fn_ty, "lg2.approx.f32 $0, $1;", "=f,f", false);
#endif
  for(auto idx: idxs_.at(x)){
    Value *lg2arg = call(lg2, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
    vals_[x][idx] = fmul(lg2arg, rcplog2e);
@@ -1612,6 +1667,35 @@ void generator::visit_log_inst(ir::log_inst* x){
/**
 * \brief Code Generation for `atomic_cas`
 */
#if defined(USE_ROCM)
void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
  BasicBlock *current = builder_->GetInsertBlock();
  Module *module = current->getModule();
  Value *tid = tgt_->get_local_id(module, *builder_, 0);
  Value *pred = icmp_eq(tid, i32(0));
  BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
  BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
  add_barrier();
  tgt_->add_memfence(module, *builder_);
  cond_br(pred, tid_0_bb, tid_0_done_bb);
  builder_->SetInsertPoint(tid_0_bb);
  Value *cas_ptr = vals_[cas->get_operand(0)][{}];
  Value *cas_cmp = vals_[cas->get_operand(1)][{}];
  Value *cas_val = vals_[cas->get_operand(2)][{}];
  Value *old = atomic_cmp_xchg(cas_ptr, cas_cmp, cas_val, MaybeAlign(), AtomicOrdering::Monotonic, AtomicOrdering::Monotonic);
  old = extract_val(old, std::vector<unsigned>{0});
  Value *atom_ptr;
  atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(cas)))), "");
  atom_ptr = bit_cast(atom_ptr, ptr_ty(old->getType(), 3));
  store(old, atom_ptr);
  br(tid_0_done_bb);
  builder_->SetInsertPoint(tid_0_done_bb);
  tgt_->add_memfence(module, *builder_);
  add_barrier();
  vals_[cas][{}] = load(atom_ptr);
  add_barrier();
}
#else
void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
  BasicBlock *current = builder_->GetInsertBlock();
  Module *module = current->getModule();
@@ -1646,12 +1730,66 @@ void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
  vals_[cas][{}] = load(atom_ptr);
  add_barrier();
}
#endif // defined(USE_ROCM)
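The shape of the generated CAS code, on both backends, is "one lane does the atomic, broadcast the old value through shared memory". A HIP-style kernel sketch of the same pattern (illustrative, not generated code):

    __global__ void cas_broadcast(int* ptr, int cmp, int val, int* out) {
      __shared__ int slot;
      __syncthreads();
      if (threadIdx.x == 0)
        slot = atomicCAS(ptr, cmp, val);   // old value at *ptr
      __syncthreads();
      out[threadIdx.x] = slot;             // every thread observes the old value
    }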

/**
 * \brief Code Generation for `atomic_rmw`
 */
#if defined(USE_ROCM)
void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
  ir::value* ptr = atom->get_operand(0);
  if (atom->get_op() == ir::atomic_rmw_op_t::Add ||
      atom->get_op() == ir::atomic_rmw_op_t::Max ||
      atom->get_op() == ir::atomic_rmw_op_t::Min ||
      atom->get_op() == ir::atomic_rmw_op_t::UMax ||
      atom->get_op() == ir::atomic_rmw_op_t::UMin ||
      atom->get_op() == ir::atomic_rmw_op_t::Xchg) {
    BasicBlock *current = builder_->GetInsertBlock();
    Module *module = current->getModule();
    Value *rmw_ptr = vals_[atom->get_operand(0)][{}];
    Value *rmw_val = vals_[atom->get_operand(1)][{}];
    Value *tid = tgt_->get_local_id(module, *builder_, 0);
    Value *pred = icmp_eq(tid, i32(0));
    BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
    BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
    tgt_->add_memfence(module, *builder_);
    add_barrier();
    cond_br(pred, tid_0_bb, tid_0_done_bb);
    builder_->SetInsertPoint(tid_0_bb);
    AtomicRMWInst::BinOp binop;
    switch (atom->get_op()) {
    case ir::atomic_rmw_op_t::Add:
      binop = AtomicRMWInst::Add;
      break;
    case ir::atomic_rmw_op_t::Max:
      binop = AtomicRMWInst::Max;
      break;
    case ir::atomic_rmw_op_t::Min:
      binop = AtomicRMWInst::Min;
      break;
    case ir::atomic_rmw_op_t::UMax:
      binop = AtomicRMWInst::UMax;
      break;
    case ir::atomic_rmw_op_t::UMin:
      binop = AtomicRMWInst::UMin;
      break;
    case ir::atomic_rmw_op_t::Xchg:
      binop = AtomicRMWInst::Xchg;
      break;
    default:
      throw std::runtime_error("Not supported!");
    }
    atomic_rmw(binop, rmw_ptr, rmw_val, MaybeAlign(), AtomicOrdering::Monotonic,
               SyncScope::System);
    br(tid_0_done_bb);
    builder_->SetInsertPoint(tid_0_done_bb);
    tgt_->add_memfence(module, *builder_);
    return;
  }
  throw std::runtime_error("Not supported!");
}
#else // defined(USE_ROCM)
void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
  ir::value *ptr = atom->get_operand(0);
  ir::value* val = atom->get_operand(1);
  ir::value* msk = atom->get_operand(2);

@@ -1756,6 +1894,7 @@ void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
      }
    }
  }
#endif // defined(USE_ROCM)

/**
 * \brief Code Generation for `mma.884` (V100)
@@ -2834,15 +2973,20 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
  size_t red_axis = 1;
  unsigned NK = A_shapes[red_axis];
  bool is_outer = NK == 1;

#ifdef USE_ROCM
  return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
#else
  bool is_mma = layouts_->get(dot)->to_mma();
  if(!is_outer && is_mma && tgt_->as_nvidia()->sm() < 80)
  if(!is_outer && is_mma && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80)
    return visit_mma884(dot, A, B, D, NK);
  if(!is_outer && is_mma && tgt_->as_nvidia()->sm() >= 80)
  if(!is_outer && is_mma && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
    return visit_mma16816(dot, A, B, D, NK); // rename it as visit_mma_v2()?
  if (dot->get_type()->get_scalar_ty()->is_fp32_ty() &&
      A->get_type()->get_scalar_ty()->is_fp32_ty())
    return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
  throw std::runtime_error("dot has invalid operand type");
#endif
}

void generator::visit_trans_inst(ir::trans_inst* trans) {
@@ -2875,8 +3019,14 @@ Value* generator::shared_off(const std::vector<unsigned>& shapes, const std::vec

inline Value* generator::shfl_sync(Value* acc, int32_t i){
  Type* ty = acc->getType();
#ifdef USE_ROCM
  std::string asm_str = "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;";
  InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false);
#else
  std::string asm_str = "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;";
  InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false);
#endif

  if(ty->getPrimitiveSizeInBits() <= 32)
    return call(shfl, {acc, i32(i)});
  acc = bit_cast(acc, vec_ty(f32_ty, 2));
@@ -3171,12 +3321,16 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
  default: throw std::runtime_error("unreachable");
  }
  ir::value *arg = x->get_operand(0);
#ifdef USE_ROCM
  visit_reducend_inst(x, do_acc, neutral);
#else
  bool is_coalesced_scanline = layouts_->is_coalesced_scanline(x);
  bool is_a100_mma = layouts_->is_a100_mma(x);
  if (is_coalesced_scanline || is_a100_mma)
    visit_reducend_inst_fast(x, do_acc, neutral);
  else
    visit_reducend_inst(x, do_acc, neutral);
#endif
}

/**
@@ -3645,6 +3799,7 @@ Value *generator::cast_shared_layout_ptr(analysis::data_layout *layout,
}

void generator::visit_function(ir::function* fn) {
  std::cout << "generator.cc: generator::visit_function:" << std::endl;
  idxs_.clear();
  vals_.clear();
  seen_.clear();

@@ -3654,6 +3809,7 @@ void generator::visit_function(ir::function* fn) {

  // set attributes
  std::cout << "\t// set attributes" << std::endl;
  for(auto attr_pair: fn->attrs()){
    unsigned id = attr_pair.first;
    for(ir::attribute attr: attr_pair.second)

@@ -3664,19 +3820,24 @@ void generator::visit_function(ir::function* fn) {
    }
  }
  // set metadata
  std::cout << "\t// set metadata" << std::endl;
  if(tgt_->is_gpu()){
    tgt_->set_kernel(*builder_, ctx, mod_, ret);
#ifndef USE_ROCM
    Metadata *md_args[] = {
      ValueAsMetadata::get(ret),
      MDString::get(ctx, "maxntidx"),
      ValueAsMetadata::get(i32(num_warps_*32))
    };
    mod_->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(ctx, md_args));
#endif
  }
  // set arguments
  std::cout << "\t// set arguments" << std::endl;
  for(unsigned i = 0; i < fn->args().size(); i++)
    vals_[fn->args()[i]][{}] = &*(ret->arg_begin() + i);
  // create blocks
  std::cout << "\t// create blocks" << std::endl;
  auto blocks = ir::cfg::reverse_post_order(fn);
  for(ir::basic_block *block: blocks) {
    BasicBlock *dst_block = BasicBlock::Create(ctx, block->get_name(), ret);

@@ -3684,6 +3845,8 @@ void generator::visit_function(ir::function* fn) {
  }
  builder_->SetInsertPoint(bbs_[fn->blocks()[0]]);
  // create policies
#ifndef USE_ROCM
  std::cout << "\t// create policies" << std::endl;
  if(tgt_->as_nvidia()->sm() >= 80)
    for(ir::load_inst::EVICTION_POLICY evict: {ir::load_inst::EVICT_FIRST, ir::load_inst::EVICT_LAST}){
      std::string policy = (evict == ir::load_inst::EVICT_FIRST) ? "evict_first" : "evict_last";

@@ -3691,15 +3854,23 @@ void generator::visit_function(ir::function* fn) {
      InlineAsm* iasm = InlineAsm::get(FunctionType::get(i64_ty, {}), asm_str, "=l", false);
      policies_[evict] = call(iasm);
    }
#endif
  // initialize layouts
  std::cout << "\t// initialize layouts" << std::endl;
  for(auto x: layouts_->get_all()){
    visit_layout(x.second);
  }
  // generate LLVM-IR code
  std::cout << "\t// generate LLVM-IR code" << std::endl;
  for(ir::basic_block *block: blocks)
    visit_basic_block(block);
  // finalize
  std::cout << "\t// finalize" << std::endl;
  finalize_function(fn);

  // verifyFunction
  std::cout << "\t// verifyFunction" << std::endl;
  llvm::verifyFunction(*ret);
}
@@ -3723,7 +3894,11 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) {
  Value *_8 = i32(8);
  Value *_16 = i32(16);
  Value *_32 = i32(32);
#ifdef USE_ROCM
  int cc = 1; // generate ir for older CUDA cards
#else
  int cc = tgt_->as_nvidia()->sm();
#endif
  std::vector<Value*> idx_m;
  std::vector<Value*> idx_n;
  std::vector<Value*> idx_z;
@@ -4114,6 +4289,7 @@ void generator::packed_type(ir::value* i){
|
||||
}
|
||||
|
||||
void generator::visit(ir::module &src, llvm::Module &dst) {
|
||||
std::cout << "generator.cc: generator::visit" << std::endl;
|
||||
mod_ = &dst;
|
||||
ctx_ = &dst.getContext();
|
||||
builder_ = new Builder(*ctx_);
|
||||
|
@@ -15,10 +15,22 @@ namespace codegen{

// base

nvidia_cu_target* target::as_nvidia() {
  return dynamic_cast<nvidia_cu_target*>(this);
#ifdef USE_ROCM
amd_cl_target *target::as_amd()
{
  return dynamic_cast<amd_cl_target *>(this);
}
amd_cl_target *target::as_nvidia()
{
  return this->as_amd();
}
#else
// causes segfault on ROCM
nvidia_cu_target *target::as_nvidia()
{
  return dynamic_cast<nvidia_cu_target *>(this);
}
#endif

bool target::is_gpu() const {
  return is_gpu_;
@@ -41,7 +53,8 @@ Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, un
}

Instruction* amd_cl_target::add_memfence(Module *module, IRBuilder<>& builder) {
  throw std::runtime_error("not implemented");
  Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::amdgcn_s_waitcnt);
  return builder.CreateIntrinsic(Intrinsic::amdgcn_s_waitcnt, {}, {builder.getInt32(0)});
}

@@ -56,7 +69,50 @@ Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigne
}

Value* amd_cl_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
  throw std::runtime_error("not implemented on AMD");
  Function &F = *builder.GetInsertBlock()->getParent();
  Module *Mod = F.getParent();
  // We are indexing into this struct, and want to extract the grid_size_*
  // fields.
  //
  // typedef struct hsa_kernel_dispatch_packet_s {
  //   uint16_t header;
  //   uint16_t setup;
  //   uint16_t workgroup_size_x;
  //   uint16_t workgroup_size_y;
  //   uint16_t workgroup_size_z;
  //   uint16_t reserved0;
  //   uint32_t grid_size_x;
  //   uint32_t grid_size_y;
  //   uint32_t grid_size_z;
  //   ...
  // } hsa_kernel_dispatch_packet_t
  //
  Function *DispatchPtrFn =
      Intrinsic::getDeclaration(Mod, Intrinsic::amdgcn_dispatch_ptr);

  CallInst *DispatchPtr = builder.CreateCall(DispatchPtrFn, {});
  DispatchPtr->addAttribute(AttributeList::ReturnIndex, Attribute::NoAlias);
  DispatchPtr->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
  F.removeFnAttr("amdgpu-no-dispatch-ptr");

  // Size of the dispatch packet struct.
  DispatchPtr->addDereferenceableAttr(AttributeList::ReturnIndex, 64);

  Type *I32Ty = Type::getInt32Ty(Mod->getContext());
  // TODO: include AMDGPUAS:: declarations.
  Value *CastDispatchPtr = builder.CreateBitCast(
      DispatchPtr, PointerType::get(I32Ty, 4 /*AMDGPUAS::CONSTANT_ADDRESS*/));

  // grid_size_x offset is 3*32bit
  assert(ax < 3);
  Value *GEP =
      builder.CreateConstInBoundsGEP1_64(I32Ty, CastDispatchPtr, ax + 3);
  LoadInst *Load = builder.CreateAlignedLoad(I32Ty, GEP, Align(4));

  MDNode *MD = MDNode::get(Mod->getContext(), None);
  Load->setMetadata(LLVMContext::MD_invariant_load, MD);

  return Load; // throw std::runtime_error("not implemented on AMD");
}

Value* amd_cl_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
@@ -156,7 +212,7 @@ Value* cpu_target::get_block_id(Module *module, llvm::IRBuilder<> &builder, unsi
}

Value* cpu_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
  throw std::runtime_error("not implemented");
  throw std::runtime_error("not implemented on CPU");
}

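The `ax + 3` index in the GEP above is easiest to verify against the packet layout quoted in the comment. A quick sketch (not part of the port) that walks the leading fields of hsa_kernel_dispatch_packet_t and confirms grid_size_x starts at byte offset 12, i.e. the fourth 32-bit word:

import struct

# Leading fields of hsa_kernel_dispatch_packet_t, as in the comment above.
fields = [
    ("header",           "H"),  # uint16_t
    ("setup",            "H"),
    ("workgroup_size_x", "H"),
    ("workgroup_size_y", "H"),
    ("workgroup_size_z", "H"),
    ("reserved0",        "H"),
    ("grid_size_x",      "I"),  # uint32_t
    ("grid_size_y",      "I"),
    ("grid_size_z",      "I"),
]

offset = 0
offsets = {}
for fname, fmt in fields:
    offsets[fname] = offset
    offset += struct.calcsize(fmt)

# grid_size_x lives at byte 12 == 32-bit word 3, hence the GEP index ax + 3.
assert offsets["grid_size_x"] == 12
assert [offsets[f"grid_size_{a}"] // 4 for a in "xyz"] == [3, 4, 5]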
@@ -222,6 +222,7 @@ bool dispatch::hipinit(){
  return res;
}

#define HIP_DEFINE0(ret, fname) DEFINE0(hipinit, hip_, ret, fname)
#define HIP_DEFINE1(ret, fname, t1) DEFINE1(hipinit, hip_, ret, fname, t1)
#define HIP_DEFINE2(ret, fname, t1, t2) DEFINE2(hipinit, hip_, ret, fname, t1, t2)
#define HIP_DEFINE3(ret, fname, t1, t2, t3) DEFINE3(hipinit, hip_, ret, fname, t1, t2, t3)
@@ -278,7 +279,8 @@ HIP_DEFINE2(hipError_t, hipEventCreate, hipEvent_t *, unsigned int)
HIP_DEFINE3(hipError_t, hipEventElapsedTime, float *, hipEvent_t, hipEvent_t)
HIP_DEFINE2(hipError_t, hipEventRecord, hipEvent_t, hipStream_t)
HIP_DEFINE1(hipError_t, hipEventDestroy, hipEvent_t)

// error handling
HIP_DEFINE0(hipError_t, hipGetLastError)

/* ------------------- *
 * COMMON
@@ -25,6 +25,8 @@
#endif
#include <memory>
#include <regex>
#include <iomanip>
#include <string>
#include "triton/driver/llvm.h"
#include "triton/driver/dispatch.h"
#include "triton/driver/error.h"
@@ -57,6 +59,8 @@
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/Intrinsics.h"
// end AMD stuff

extern "C"
@@ -67,6 +71,24 @@ extern "C"
int setupterm(char *term, int fildes, int *errret) { return 0; }
}

namespace {

std::string gen_random(const int len) {
  static const char alphanum[] =
      "0123456789"
      "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
      "abcdefghijklmnopqrstuvwxyz";
  std::string tmp_s;
  tmp_s.reserve(len);

  for (int i = 0; i < len; ++i) {
    tmp_s += alphanum[rand() % (sizeof(alphanum) - 1)];
  }
  return tmp_s;
}

}

namespace triton
{
  namespace driver
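gen_random above only feeds the temporary-file naming inside llir_to_amdgcn below, which avoids collisions in /tmp between runs. A rough Python equivalent of the whole naming scheme (module identifier + timestamp + random suffix); the helper name is invented for illustration:

import random
import string
from datetime import datetime

def unique_kernel_name(module_id: str, suffix_len: int = 12) -> str:
    # "module_id" stands in for module->getModuleIdentifier() on the C++ side.
    stamp = datetime.now().strftime("%Y-%m-%d--%I-%M-%S")
    suffix = "".join(random.choices(string.ascii_letters + string.digits, k=suffix_len))
    return f"{module_id}_{stamp}_{suffix}"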
@@ -266,20 +288,24 @@ namespace triton
    /* ------------------------ */
    //           HIP            //
    /* ------------------------ */

    std::string llir_to_amdgpu(llvm::Module *module, const std::string &_proc)
    std::tuple<std::string, std::string> llir_to_amdgcn(llvm::Module *module, const std::string &_proc)
    {
      std::cout << "llvm.cc: llir_to_amdgcn:" << std::endl;
      init_llvm();

      // proc = std::get<0>(GetFeatureStrFromGCNArchName(rocminfo));
      // features = std::get<1>(GetFeatureStrFromGCNArchName(rocminfo));

      // create
      llvm::SmallVector<char, 0> buffer;
      std::string triple = "amdgcn-amd-amdhsa";
      std::string layout = "";
      std::string features;
      std::string proc = "gfx908";
      std::string features = "+sramecc,-xnack";
      std::string proc = _proc;
      // name kernel
      auto in_time_t = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
      std::stringstream cur_time;
      cur_time << std::put_time(std::localtime(&in_time_t), "%Y-%m-%d--%I-%M-%S");
      std::string kernel_name = module->getModuleIdentifier() + "_" + cur_time.str() + "_" + gen_random(12);
      // verify and store llvm
      llvm::legacy::PassManager pm;
      pm.add(llvm::createVerifierPass());
@@ -295,7 +321,7 @@ namespace triton
      opt.NoNaNsFPMath = true;
      llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
                                                                 llvm::Reloc::PIC_, llvm::None,
                                                                 llvm::CodeGenOpt::Aggressive);
                                                                 llvm::CodeGenOpt::None);
      // set data layout
      if (layout.empty())
        module->setDataLayout(machine->createDataLayout());
@@ -308,11 +334,10 @@ namespace triton
      llvm::raw_svector_ostream stream(buffer);

      // create dump files
      std::string module_name = module->getModuleIdentifier();
      std::error_code ec;

      // Save GCN ISA binary.
      std::string isabin_path = std::string("/tmp/") + module_name + std::string(".o");
      std::string isabin_path = std::string("/tmp/") + kernel_name + std::string(".o");
      std::unique_ptr<llvm::raw_fd_ostream> isabin_fs(
          new llvm::raw_fd_ostream(isabin_path, ec, llvm::sys::fs::OF_Text));
      if (ec)
@@ -323,15 +348,17 @@ namespace triton
      // emit
      machine->addPassesToEmitFile(pass, *isabin_fs, nullptr, llvm::CGFT_ObjectFile);
      pass.run(*module);

      // Save GCN ISA.
      std::string amdgcn_path = std::string("/tmp/") + module_name + std::string(".gcn");
      std::string result(buffer.begin(), buffer.end());
      std::ofstream amdgcn(amdgcn_path);
      amdgcn << result;
      amdgcn.close();
      llvm::SmallVector<char, 0> debugBuffer;
      llvm::legacy::PassManager debugPass;
      llvm::raw_svector_ostream debugStream(debugBuffer);
      machine->addPassesToEmitFile(debugPass, debugStream, nullptr, llvm::CodeGenFileType::CGFT_AssemblyFile); // TODO: causes segfault on REM ops; also causes the @llvm.amdgcn.if bug
      debugPass.run(*module);
      std::string amdgcn(debugBuffer.begin(), debugBuffer.end());

      // generate HSACO file
      std::string hsaco_path = std::string("/tmp/") + module_name + std::string(".hsaco");
      std::string hsaco_path = std::string("/tmp/") + kernel_name + std::string(".hsaco");
      std::string error_message;
      int lld_result =
          llvm::sys::ExecuteAndWait("/opt/rocm/llvm/bin/ld.lld",
@@ -344,13 +371,14 @@ namespace triton
        std::cout << lld_result << std::endl;
      }

      return hsaco_path;
      return std::make_tuple(amdgcn, hsaco_path);
    }

    hipModule_t amdgpu_to_hipmodule(const std::string &path)
    hipModule_t amdgpu_to_hipmodule(const std::string &hsaco_path)
    {
      std::cout << "llvm.cc: amdgpu_to_hipmodule:" << std::endl;
      // Read HSACO.
      std::ifstream hsaco_file(path, std::ios::binary | std::ios::ate);
      std::ifstream hsaco_file(hsaco_path, std::ios::binary | std::ios::ate);
      std::ifstream::pos_type hsaco_file_size = hsaco_file.tellg();

      std::vector<unsigned char> hsaco(hsaco_file_size);
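The ld.lld call elided by the hunk boundary above turns the emitted GCN object file into a loadable HSA code object. A minimal Python sketch of that link step; the flags shown are illustrative of how LLD is typically driven for HSACO output, not a verbatim copy of what the port passes:

import subprocess

def link_hsaco(isabin_path: str, hsaco_path: str, lld: str = "/opt/rocm/llvm/bin/ld.lld") -> None:
    # Link the GCN object file into a shared HSA code object.
    subprocess.check_call([lld, "-flavor", "gnu", "-shared", "-o", hsaco_path, isabin_path])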
@@ -60,6 +60,9 @@ value *builder::get_float16(float val)
value *builder::get_float32(float val)
{ return constant_fp::get(type::get_fp32_ty(ctx_), val); }

value *builder::get_float64(float val)
{ return constant_fp::get(type::get_fp64_ty(ctx_), val); }

value *builder::get_range(int32_t _lo, int32_t _hi) {
  constant_int* lo = static_cast<constant_int*>(get_int32(_lo));
  constant_int* hi = static_cast<constant_int*>(get_int32(_hi));
@@ -13,6 +13,7 @@ from typing import NamedTuple

from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext
import torch


# Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py
@@ -32,7 +33,8 @@ def get_build_type():
def use_system_llvm():
    if platform.system() == "Windows":
        return True
    versions = ['-11.0', '-11', '-11-64']
    # versions = ['-11.0', '-11', '-11-64']
    versions = ['-13.0', '-13', '-13-64']
    supported = ['llvm-config{v}'.format(v=v) for v in versions]
    paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
    return any(p is not None for p in paths)
@@ -53,7 +55,7 @@ def get_thirdparty_packages(triton_cache_path):
    if not use_system_llvm():
        # download LLVM if no suitable system LLVM is installed
        packages.append(
            Package("llvm", "clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04", "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04.tar.xz", "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR")
            Package("llvm", "clang+llvm-13.0.0-x86_64-linux-gnu-ubuntu-16.04", "https://github.com/llvm/llvm-project/releases/download/llvmorg-13.0.0/clang+llvm-13.0.0-x86_64-linux-gnu-ubuntu-16.04.tar.xz", "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR")
        )

        thirdparty_cmake_args = []
@@ -144,6 +146,9 @@ class CMakeBuild(build_ext):
        build_args += ["--", '-j' + str(2 * multiprocessing.cpu_count())]

        env = os.environ.copy()

        if torch.version.hip is not None:
            env["TRITON_USE_ROCM"] = "ON"
        subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=self.build_temp, env=env)
        subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp)
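The essence of the setup.py change is a single switch: a PyTorch build for ROCm reports a non-None torch.version.hip, and that is forwarded to the CMake invocation through the TRITON_USE_ROCM environment variable. A stripped-down sketch of just that detection:

import os
import torch

# A ROCm build of PyTorch is the signal that Triton itself should target ROCm.
env = os.environ.copy()
if torch.version.hip is not None:
    env["TRITON_USE_ROCM"] = "ON"
# subprocess.check_call(["cmake", ...], env=env) then selects the ROCm paths.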
@@ -100,7 +100,9 @@ void hip_enqueue(uint64_t stream, uint64_t kernel,
  drv::dispatch::hipModuleLaunchKernel((hipFunction_t)kernel, grid_0, grid_1, grid_2,
                                       block_0, block_1, block_2,
                                       shared_mem, (hipStream_t)stream, nullptr, config);

#ifdef DEBUG_ROCM
  drv::dispatch::hipGetLastError();
#endif
}

long pow2_divisor(long N){
@@ -435,7 +437,7 @@ typedef std::map<std::string, py::object> asm_map_t;
// ---------------------------------------

void init_triton_codegen(py::module &&m) {
  m.def("compile_ttir",
  m.def("compile_ttir_to_ptx",
      [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs, size_t cc) {
        std::ostringstream ttir;
        int n_shared_bytes;
@@ -490,11 +492,64 @@ void init_triton_codegen(py::module &&m) {
      },
      py::return_value_policy::take_ownership);

  m.def("compile_ttir_to_amdgcn",
      [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs, const std::string& gfx_arch) {
        std::ostringstream ttir;
        int n_shared_bytes;
        std::string tmp;
        std::string amdgcn;
        std::string hsaco_path;
        std::string name;
        {
          std::cout << "triton.cc: compile_ttir_to_amdgcn:" << std::endl;
          // Scope where the GIL is released
          py::gil_scoped_release allow_threads;
          name = ir.get_function_list()[0]->get_name();
          ir.print(ttir);
          llvm::LLVMContext ctx;
          // construct extern lib map
          triton::codegen::ExternLibMap extern_lib_map;
          for (auto item : extern_libs) {
            auto name = item.first.cast<std::string>();
            auto path = item.second.cast<std::string>();
            extern_lib_map.emplace(
                name, triton::codegen::create_extern_lib(name, path));
          }
          int version;
          // std::string ptxas_path = drv::path_to_ptxas(version);
          // Triton-IR -> AMDGCN LLVM-IR
          std::cout << "ttir:" << std::endl;
          std::cout << "\t" << ttir.str() << std::endl;
          triton::codegen::amd_cl_target target;
          auto llvm = triton::codegen::add_passes_to_emit_bin(
              ir, ctx, &target, num_warps, num_stages, n_shared_bytes, extern_lib_map);
          llvm::raw_string_ostream llir(tmp);
          llir << *llvm;
          std::cout << "llir:" << std::endl;
          std::cout << "\t" << llir.str() << std::endl;
          llir.flush();
          // LLVM-IR -> AMDGPU
          std::tuple<std::string, std::string> amdgpu = drv::llir_to_amdgcn(llvm.get(), gfx_arch);
          amdgcn = std::get<0>(amdgpu);
          hsaco_path = std::get<1>(amdgpu);
          std::cout << "amdgcn:" << std::endl;
          std::cout << "\t" << amdgcn << std::endl;
        }
        asm_map_t asm_map;
        asm_map["ttir"] = py::cast(ttir.str());
        asm_map["llir"] = py::cast(tmp);
        asm_map["amdgcn"] = py::cast(amdgcn);
        asm_map["hsaco_path"] = py::cast(hsaco_path);

        return std::make_tuple(name, asm_map, n_shared_bytes);
      },
      py::return_value_policy::take_ownership);


  // ---------------------------------------
  // Load provided assembly code into driver
  // ---------------------------------------
  m.def("load_binary", [](const std::string& name, const std::string& data, size_t n_shared_bytes, uint64_t device){
  m.def("load_binary_cubin", [](const std::string& name, const std::string& data, size_t n_shared_bytes, uint64_t device){
    py::gil_scoped_release allow_threads;
    // create driver handles
    CUfunction fun;
@@ -521,6 +576,47 @@ void init_triton_codegen(py::module &&m) {
      },
      py::return_value_policy::take_ownership
  );

  // ---------------------------------------
  // Load provided assembly code into driver
  // ---------------------------------------
  m.def("load_binary_hsaco", [](const std::string& name, const std::string& data, size_t n_shared_bytes, uint64_t device){
    std::cout << "triton.cc: load_binary_hsaco:" << std::endl;
    std::cout << "\tname:" << name << std::endl;
    std::cout << "\tdata:" << data << std::endl;
    std::cout << "\tn_shared_bytes:" << n_shared_bytes << std::endl;
    std::cout << "\tdevice:" << device << std::endl;
    py::gil_scoped_release allow_threads;
    // create driver handles
    std::cout << "\t" << "// create driver handles" << std::endl;
    hipFunction_t fun;
    hipModule_t mod = drv::amdgpu_to_hipmodule(data);
    drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str());
    // get allocated registers and spilled registers from the function
    std::cout << "\t" << "// get allocated registers and spilled registers from the function" << std::endl;
    int n_regs = 0;
    int n_spills = 0;
    hipFuncAttributes attr;
    // drv::dispatch::hipFuncGetAttributes(&attr, fun);
    n_regs = attr.numRegs;
    n_spills = attr.localSizeBytes / 4;
    // set dynamic shared memory if necessary
    std::cout << "\t" << "// set dynamic shared memory if necessary" << std::endl;
    int shared_optin;
    // drv::dispatch::hipDeviceGetAttribute(&shared_optin, hipDeviceAttributeSharedMemPerBlockOptin, device);
    if(n_shared_bytes > 49152 && shared_optin > 49152){
      // drv::dispatch::hipFuncSetCacheConfig(fun, hipFuncCachePreferShared);
      int shared_total, shared_static;
      // drv::dispatch::hipDeviceGetAttribute(&shared_total, hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, device);
      // drv::dispatch::hipFuncGetAttributes(&attr, fun);
      shared_total = attr.sharedSizeBytes;
      // drv::dispatch::hipFuncSetAttribute(fun, hipFuncAttributeMaxDynamicSharedMemorySize, shared_optin - shared_static);
    }
    return std::make_tuple((uint64_t)mod, (uint64_t)fun, (uint64_t)n_regs, (uint64_t)n_spills);
  },
  py::return_value_policy::take_ownership
  );


struct InstanceDescriptor
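Taken together with compile_ttir_to_amdgcn above, the intended calling sequence from Python looks roughly like the sketch below: a condensed view of what compiler.py does further down, not an additional public API.

import triton._C.libtriton.triton as _triton

def compile_and_load_hsaco(module, device, num_warps, num_stages, extern_libs, gfx_arch="gfx908"):
    # Compile Triton IR to AMDGCN + a HSACO path, then load it via HIP.
    backend = _triton.runtime.backend.ROCM
    name, asm, n_shared = _triton.code_gen.compile_ttir_to_amdgcn(
        backend, module, device, num_warps, num_stages, extern_libs, gfx_arch)
    return _triton.code_gen.load_binary_hsaco(name, asm["hsaco_path"], n_shared, device)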
@@ -104,9 +104,12 @@ def check_type_supported(dtype):
    '''
    skip test if dtype is not supported on the current device
    '''
    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
    if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
        pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
    if torch.version.hip is not None:
        pass
    else:
        cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
        if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
            pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")


@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes] + ["bfloat16"])
@@ -123,6 +126,9 @@ def test_empty_kernel(dtype_x, device='cuda'):

# generic test functions
def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
    if torch.version.hip is not None:
        if dtype_x == "bfloat16":
            pytest.skip("unary op with bfloat is not supported on AMDGPU")
    check_type_supported(dtype_x)  # early return if dtype_x is not supported
    SIZE = 128
    # define the kernel / launch-grid
@@ -230,7 +236,6 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
    ('int64', 'float32'),
    ('int64', 'float64'),
    ('uint16', 'bfloat16'),
    ('uint16', 'float16'),
    ('uint16', 'float32'),
    ('uint32', 'bfloat16'),
    ('uint32', 'float16'),
@@ -253,8 +258,12 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
    for dtype_y in dtypes_with_bfloat16
])
def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
    if torch.version.hip is not None:
        if dtype_x == "bfloat16" and dtype_y == "bfloat16":
            pytest.skip("binary op with bfloat is not supported on AMDGPU")

    expr = f' x {op} y'
    if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes:
    if op == '%' and (dtype_x in dtypes and dtype_y in dtypes):
        # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders.
        numpy_expr = 'np.fmod(x, y)'
    elif op in ('/', '%') and dtype_x in ('int16', 'float16', 'bfloat16') and dtype_y in ('int16', 'float16', 'bfloat16'):
@@ -605,6 +614,10 @@ def test_tuples():
    ]
    for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
    if torch.version.hip is not None:
        # if dtype_x_str in ["uint32", "int32", "float32"]:
        pytest.skip(f"test_atomic_rmw[{dtype_x_str}] currently has segfaults on ROCM")

    n_programs = 5

    # triton kernel
@@ -671,6 +684,8 @@ def test_tensor_atomic_rmw(axis, device="cuda"):


def test_atomic_cas():
    if torch.version.hip is not None:
        pytest.skip("test_atomic_cas currently has segfaults on ROCM")
    # 1. make sure that atomic_cas changes the original value (Lock)
    @triton.jit
    def change_value(Lock):
@@ -778,6 +793,8 @@ def test_store_bool():

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_f8_xf16_roundtrip(dtype):
    if torch.version.hip is not None:
        pytest.skip("test_f8_xf16_roundtrip currently has segfaults on ROCM")
    """Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
    check_type_supported(dtype)

@@ -804,6 +821,8 @@ def test_f8_xf16_roundtrip(dtype):


def test_f16_to_f8_rounding():
    if torch.version.hip is not None:
        pytest.skip("test_f16_to_f8_rounding currently has segfaults on ROCM")
    """Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
    error is the minimum over all float8.
    Or the same explanation a bit mathier:
@@ -872,6 +891,9 @@ def test_f16_to_f8_rounding():
    for shape in [32, 64, 128, 512]])
def test_reduce1d(op, dtype_str, shape, device='cuda'):
    check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested
    if torch.version.hip is not None:
        if dtype_str in ["int8", "int16", "uint8", "uint16"]:
            pytest.skip(f"test_reduce1d[{dtype_str}] skipped on ROCM")

    # triton kernel
    @triton.jit
@@ -930,6 +952,10 @@ reduce_configs2 = [

@pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2)
def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
    check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested
    if torch.version.hip is not None:
        if dtype_str in ["int8", "int16", "uint8", "uint16"]:
            pytest.skip(f"test_reduce2d[{dtype_str}] skipped on ROCM")
    # triton kernel
    @triton.jit
    def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
@@ -1021,13 +1047,18 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
    # compare
    triton.testing.assert_almost_equal(z_tri, z_ref)
    triton.testing.assert_almost_equal(z_tri_contiguous, z_ref)
    # parse ptx to make sure ld/st are vectorized
    ptx = pgm.asm['ptx']
    assert 'ld.global.v4' in ptx
    assert 'st.global.v4' in ptx
    ptx = pgm_contiguous.asm['ptx']
    assert 'ld.global.v4' in ptx
    assert 'st.global.v4' in ptx

    if torch.version.hip is None:
        # parse ptx to make sure ld/st are vectorized
        ptx = pgm.asm['ptx']
        assert 'ld.global.v4' in ptx
        assert 'st.global.v4' in ptx
        ptx = pgm_contiguous.asm['ptx']
        assert 'ld.global.v4' in ptx
        assert 'st.global.v4' in ptx
    else:
        # TODO add rocm gcn assert
        pass

# ---------------
# test dot
@@ -1041,12 +1072,15 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
    for dtype in ['float16']
    if not (allow_tf32 and (dtype in ['float16']))])
def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
    if cc < 80:
        if dtype == 'int8':
            pytest.skip("Only test int8 on devices with sm >= 80")
        elif dtype == 'float32' and allow_tf32:
            pytest.skip("Only test tf32 on devices with sm >= 80")
    if torch.version.hip is not None:
        pass
    else:
        cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
        if cc < 80:
            if dtype == 'int8':
                pytest.skip("Only test int8 on devices with sm >= 80")
            elif dtype == 'float32' and allow_tf32:
                pytest.skip("Only test tf32 on devices with sm >= 80")

    M, N, K = 128, 128, 64
    num_warps = 8
@@ -1141,15 +1175,18 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
    # print(z_ref[:,0], z_tri[:,0])
    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
    # make sure ld/st are vectorized
    ptx = pgm.asm['ptx']
    assert 'ld.global.v4' in ptx
    assert 'st.global.v4' in ptx
    if allow_tf32:
        assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
    elif dtype == 'float32':
        assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
    elif dtype == 'int8':
        assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
    if torch.version.hip is not None:
        pass
    else:
        ptx = pgm.asm['ptx']
        assert 'ld.global.v4' in ptx
        assert 'st.global.v4' in ptx
        if allow_tf32:
            assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
        elif dtype == 'float32':
            assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
        elif dtype == 'int8':
            assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx


def test_dot_without_load():
@@ -1227,6 +1264,8 @@ def test_masked_load(dtype_str, size, size_diff, device='cuda'):

@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_masked_load_shared_memory(dtype, device='cuda'):
    if torch.version.hip is not None:
        pytest.skip("test_masked_load_shared_memory currently has segfaults on ROCM")
    check_type_supported(dtype)  # bfloat16 on cc < 80 will not be tested

    M = 32
@@ -1286,16 +1325,20 @@ def test_load_cache_modifier(cache):
        tl.store(dst + offsets, x)

    pgm = _kernel[(1,)](dst, src, CACHE=cache)
    ptx = pgm.asm['ptx']
    if cache == '':
        assert 'ld.global.ca' not in ptx
        assert 'ld.global.cg' not in ptx
    if cache == '.cg':
        assert 'ld.global.cg' in ptx
        assert 'ld.global.ca' not in ptx
    if cache == '.ca':
        assert 'ld.global.ca' in ptx
        assert 'ld.global.cg' not in ptx
    if torch.version.hip is None:
        ptx = pgm.asm['ptx']
        if cache == '':
            assert 'ld.global.ca' not in ptx
            assert 'ld.global.cg' not in ptx
        if cache == '.cg':
            assert 'ld.global.cg' in ptx
            assert 'ld.global.ca' not in ptx
        if cache == '.ca':
            assert 'ld.global.ca' in ptx
            assert 'ld.global.cg' not in ptx
    else:
        # TODO add rocm gcn assert
        pass


@pytest.mark.parametrize("N", [16, 10, 11, 1024])
@@ -1309,11 +1352,15 @@ def test_vectorization(N):
        x = tl.load(src + offsets, mask=offsets < N)
        tl.store(dst + offsets, x, mask=offsets < N)
    pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
    ptx = pgm.asm["ptx"]
    if N % 16 == 0:
        assert "ld.global.v4.b32" in ptx
    if torch.version.hip is None:
        ptx = pgm.asm["ptx"]
        if N % 16 == 0:
            assert "ld.global.v4.b32" in ptx
        else:
            assert "ld.global.b32" in ptx
    else:
        assert "ld.global.b32" in ptx
        # TODO add rocm assert
        pass
    # triton.testing.assert_almost_equal(dst, src[:N])
# ---------------
# test store
@@ -1547,7 +1594,8 @@ def test_num_warps_pow2():
    ('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
    ('float64', 'libdevice.norm4d', '')])
def test_libdevice_tensor(dtype_str, expr, lib_path):

    if torch.version.hip is not None:
        pytest.skip("test_libdevice_tensor currently has segfaults on ROCM")
    @triton.jit
    def kernel(X, Y, BLOCK: tl.constexpr):
        x = tl.load(X + tl.arange(0, BLOCK))
@@ -1587,6 +1635,8 @@ def test_libdevice_tensor(dtype_str, expr, lib_path):
@pytest.mark.parametrize("dtype_str, expr, lib_path",
                         [('float32', 'libdevice.pow', '')])
def test_libdevice_scalar(dtype_str, expr, lib_path):
    if torch.version.hip is not None:
        pytest.skip("test_libdevice_scalar currently has segfaults on ROCM")

    @triton.jit
    def kernel(X, Y, BLOCK: tl.constexpr):
@@ -7,6 +7,7 @@ import hashlib
import io
import json
import os
import re
import shutil
import subprocess
import sys
@@ -24,6 +25,12 @@ import triton
import triton._C.libtriton.triton as _triton
from .tools.disasm import extract

def static_vars(**kwargs):
    def decorate(func):
        for k in kwargs:
            setattr(func, k, kwargs[k])
        return func
    return decorate

def str_to_ty(name):
    if name[0] == "*":
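static_vars simply pins keyword arguments onto the function object at decoration time; it is used below to cache the discovered gfx arch once at import. A self-contained usage example:

@static_vars(counter=0)  # uses the decorator defined above
def bump():
    # The attribute lives on the function object itself, like a C static.
    bump.counter += 1
    return bump.counter

assert bump() == 1
assert bump() == 2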
@@ -880,31 +887,55 @@ def ptx_get_kernel_name(ptx: str) -> str:
        if line.startswith('// .globl'):
            return line.split()[-1]

@functools.lru_cache()
def rocm_path_dir():
    return os.getenv("ROCM_PATH", default="/opt/rocm")

def _get_amdgpu_arch():
    try:
        rocminfo = subprocess.check_output(rocm_path_dir() + '/bin/rocminfo').decode()
        gfx_arch = re.search('Name:\\s+.*(gfx\\d+)', rocminfo)
        return gfx_arch.group(1).strip()
    except:
        return None

@static_vars(discovered_gfx_arch=_get_amdgpu_arch())
def _compile(fn, signature: str, device: int = -1, constants=dict(),
             specialization=_triton.code_gen.instance_descriptor(),
             num_warps: int = 4, num_stages: int = 3, extern_libs=None,
             output: str = "ttgir", cc=0) -> Tuple[str, int, str]:
    print("compiler.py: _compile")
    print(f"\t{fn, signature, device, constants, specialization, num_warps, num_stages, extern_libs, output, cc}")
    valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
    assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
    # assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)

    # triton-ir
    module, _ = make_triton_ir(fn, signature, specialization, constants)
    if output == "ttir":
        return module

    assert output == "cubin"
    assert torch.version.hip is None
    backend = _triton.runtime.backend.CUDA
    assert (output == "cubin" or output == "hsaco")
    if torch.version.hip is not None:
        backend = _triton.runtime.backend.ROCM
    else:
        backend = _triton.runtime.backend.CUDA
    if extern_libs is None:
        extern_libs = dict()
    name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, module, device, num_warps, num_stages, extern_libs, cc)

    # compile ttir
    if torch.version.hip is not None:
        gfx_arch = os.environ.get('MI_GPU_ARCH', _compile.discovered_gfx_arch)
        if gfx_arch is None:
            raise RuntimeError('AMDGCN gfx arch is not defined.')
        name, asm, shared_mem = _triton.code_gen.compile_ttir_to_amdgcn(backend, module, device, num_warps, num_stages, extern_libs, gfx_arch)
    else:
        name, asm, shared_mem = _triton.code_gen.compile_ttir_to_ptx(backend, module, device, num_warps, num_stages, extern_libs, cc)
    return asm, shared_mem, name


def ty_to_cpp(ty):
    if ty[0] == '*':
        return "CUdeviceptr"
        return "hipDeviceptr_t"
    return {
        "i1": "int32_t",
        "i8": "int8_t",
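The arch discovery above shells out to rocminfo and scrapes the agent name. Exercising the same regex on a line of the shape rocminfo prints (the sample text is assumed, not captured from a real run):

import re

sample = "  Name:                    gfx908"  # assumed rocminfo output line
match = re.search('Name:\\s+.*(gfx\\d+)', sample)
assert match is not None and match.group(1) == "gfx908"

Note that gfx\d+ stops at the first non-digit, so an arch such as gfx90a would be captured as gfx90; the MI_GPU_ARCH environment override in _compile provides the escape hatch for those targets.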
@@ -967,17 +998,18 @@ def generate_launcher(identifier, constants, signature):
    format = "iiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])

    # generate glue code
    src = f"""
#include \"cuda.h\"
    if torch.version.hip is not None:
        src = f"""
#define __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#include <Python.h>

static inline void gpuAssert(CUresult code, const char *file, int line)
static inline void gpuAssert(hipError_t code, const char *file, int line)
{{
  if (code != CUDA_SUCCESS)
  if (code != HIP_SUCCESS)
  {{
    const char* prefix = "Triton Error [CUDA]: ";
    const char* str;
    cuGetErrorString(code, &str);
    const char* str = hipGetErrorString(code);
    char err[1024] = {{0}};
    strcat(err, prefix);
    strcat(err, str);
@@ -987,20 +1019,20 @@ static inline void gpuAssert(CUresult code, const char *file, int line)
#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}


void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, CUstream stream, CUfunction function, {arg_decls}) {{
void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, hipStream_t stream, hipFunction_t function, {arg_decls}) {{
  void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
  if(gridX*gridY*gridZ > 0){{
    CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
    hipModuleLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0);
  }}
}}


static inline CUdeviceptr getPointer(PyObject *obj, int idx) {{
static inline hipDeviceptr_t getPointer(PyObject *obj, int idx) {{
  if (PyLong_Check(obj)) {{
    return (CUdeviceptr)PyLong_AsUnsignedLongLong(obj);
    return (hipDeviceptr_t)PyLong_AsUnsignedLongLong(obj);
  }}
  if (obj == Py_None) {{
    return (CUdeviceptr)0;
    return (hipDeviceptr_t)0;
  }}
  PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
  if(ptr){{
@@ -1011,14 +1043,15 @@ static inline CUdeviceptr getPointer(PyObject *obj, int idx) {{
    if (!PyLong_Check(ret)) {{
      PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
    }}
    return (CUdeviceptr)PyLong_AsUnsignedLongLong(ret);
    return (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
  }}
  PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
  return (CUdeviceptr)0;
  return (hipDeviceptr_t)0;
}}


static PyObject* launch(PyObject* self, PyObject* args) {{
  // printf("launch(PyObject* self, PyObject* args)");
  int gridX, gridY, gridZ;
  uint64_t _stream;
  uint64_t _function;
@@ -1039,7 +1072,7 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
    Py_DECREF(new_args);
  }}

  _launch(gridX, gridY, gridZ, num_warps, shared_memory, (CUstream)_stream, (CUfunction)_function, {', '.join(f"getPointer(_arg{i},{i})" if ty[0]=="*" else f"_arg{i}" for i, ty in signature.items())});
  _launch(gridX, gridY, gridZ, num_warps, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, {', '.join(f"getPointer(_arg{i},{i})" if ty[0]=="*" else f"_arg{i}" for i, ty in signature.items())});

  if (launch_exit_hook != Py_None) {{
    PyObject *new_args = NULL;
@@ -1134,6 +1167,10 @@ def libcuda_dirs():
    locs = subprocess.check_output(["whereis", "libcuda.so"]).decode().strip().split()[1:]
    return [os.path.dirname(loc) for loc in locs]

@functools.lru_cache()
def libhip_dirs():
    return ["/opt/rocm/lib"]


@functools.lru_cache()
def cuda_home_dirs():
@@ -1141,6 +1178,15 @@ def cuda_home_dirs():
    return os.getenv("CUDA_HOME", default=default_dir)


@functools.lru_cache()
def hip_home_dirs():
    default_dir = "/opt/rocm"
    return os.getenv("ROCM_HOME", default=default_dir)

@functools.lru_cache()
def rocm_path_dir():
    return os.getenv("ROCM_PATH", default="/opt/rocm")

@contextlib.contextmanager
def quiet():
    old_stdout, old_stderr = sys.stdout, sys.stderr
@@ -1152,8 +1198,15 @@ def quiet():


def _build(name, src, srcdir):
    cuda_lib_dirs = libcuda_dirs()
    cu_include_dir = os.path.join(cuda_home_dirs(), "include")
    print("compiler.py: _build")
    print(f"\t{name, src, srcdir}")
    if torch.version.hip is not None:
        hip_lib_dirs = libhip_dirs()
        hip_include_dir = os.path.join(hip_home_dirs(), "include")
    else:
        cuda_lib_dirs = libcuda_dirs()
        cu_include_dir = os.path.join(cuda_home_dirs(), "include")

    suffix = sysconfig.get_config_var('EXT_SUFFIX')
    so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
    # try to avoid setuptools if possible
@@ -1164,16 +1217,29 @@ def _build(name, src, srcdir):
    gcc = shutil.which("gcc")
    cc = gcc if gcc is not None else clang
    py_include_dir = get_paths()["include"]
    cc_cmd = [cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-o", so]
    cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs]
    if torch.version.hip is not None:
        cc_cmd = [cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lamdhip64", "-o", so]
        cc_cmd += [f"-L{dir}" for dir in hip_lib_dirs]
    else:
        cc_cmd = [cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-o", so]
        cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs]
    print("\t", ' '.join(cc_cmd))
    ret = subprocess.check_call(cc_cmd)
    if ret == 0:
        print("ret:", ret)
        print(so)
        return so
    # fallback on setuptools
    extra_compile_args = []
    library_dirs = cuda_lib_dirs
    include_dirs = [srcdir, cu_include_dir]
    libraries = ['cuda']
    if torch.version.hip is not None:
        library_dirs = hip_lib_dirs
        include_dirs = [srcdir, hip_include_dir]
        libraries = ['rocm']
    else:
        library_dirs = cuda_lib_dirs
        include_dirs = [srcdir, cu_include_dir]
        libraries = ['cuda']

    # extra arguments
    extra_link_args = []
    # create extension module
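On the ROCm path the command assembled above links the generated launcher against amdhip64 instead of libcuda. For a launcher called "launcher" it comes out roughly as follows; the concrete include and library paths depend on hip_home_dirs()/libhip_dirs(), so those shown here are illustrative:

import sysconfig

srcdir, name = "/tmp/launcher", "launcher"          # hypothetical build dir
so = srcdir + "/" + name + sysconfig.get_config_var("EXT_SUFFIX")
cc_cmd = ["gcc", srcdir + "/main.c",
          "-I/opt/rocm/include",                    # hip_include_dir
          "-I" + sysconfig.get_paths()["include"],  # py_include_dir
          "-I" + srcdir,
          "-shared", "-fPIC", "-lamdhip64", "-o", so,
          "-L/opt/rocm/lib"]                        # hip_lib_dirs
print(" ".join(cc_cmd))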
@@ -1221,6 +1287,8 @@ def make_fn_cache_key(fn_hash, signature, configs, constants, num_warps, num_sta

def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4,
            num_stages: int = 3, extern_libs=None, configs=None, cc=0, warm_cache_only=False):
    print("compiler.py: compile")
    print(f"\t{fn, signature, device, constants, num_warps, num_stages, extern_libs, configs, cc, warm_cache_only}")
    # we get the kernel, i.e. the first function generated in the module
    assert len(configs) == 1
    # cache manager
@@ -1236,30 +1304,52 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
            src_path = os.path.join(tmpdir, "main.c")
            with open(src_path, "w") as f:
                f.write(src)
            so = _build(fn.__name__, src_path, tmpdir)
            so = _build(fn.__name__, src_path, tmpdir)  # build step
            with open(so, "rb") as f:
                so_cache_manager.put(f.read(), so_name, binary=True)

    # retrieve cached shared object if it exists
    fn_cache_key = make_fn_cache_key(fn.cache_key, signature, configs, constants, num_warps, num_stages)
    fn_cache_manager = CacheManager(fn_cache_key)
    ptx_name = f"{name}.ptx"
    cubin_name = f"{name}.cubin"
    if torch.version.hip is not None:
        amdgcn_name = f"{name}.gcn"
        hsaco_name = f"{name}.hsaco"
        assembly_name = amdgcn_name
        binary_name = hsaco_name
    else:
        ptx_name = f"{name}.ptx"
        cubin_name = f"{name}.cubin"
        assembly_name = ptx_name
        binary_name = cubin_name

    data_name = f"{name}.json"
    ttir_name = f"{name}.ttir"
    llir_name = f"{name}.llir"
    if not fn_cache_manager.has_file(cubin_name) or \
    if not fn_cache_manager.has_file(binary_name) or \
       not fn_cache_manager.has_file(data_name) or \
       not fn_cache_manager.has_file(ptx_name) or \
       not fn_cache_manager.has_file(assembly_name) or \
       not fn_cache_manager.has_file(ttir_name) or \
       not fn_cache_manager.has_file(llir_name):
        asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
                                            extern_libs, "cubin", cc)
        metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
        fn_cache_manager.put(asm["cubin"], cubin_name)
        fn_cache_manager.put(asm["ptx"], ptx_name, binary=False)

        if torch.version.hip is not None:
            asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
                                                extern_libs, "hsaco", cc)
            # cache AMD assembly and binary
            fn_cache_manager.put(asm["hsaco_path"], binary_name, binary=False)
            fn_cache_manager.put(asm["amdgcn"], assembly_name, binary=False)
        else:
            asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
                                                extern_libs, "cubin", cc)
            # cache Nvidia assembly and binary
            fn_cache_manager.put(asm["cubin"], binary_name)
            fn_cache_manager.put(asm["ptx"], assembly_name, binary=False)

        # cache triton and llvm ir
        fn_cache_manager.put(asm["ttir"], ttir_name, binary=False)
        fn_cache_manager.put(asm["llir"], llir_name, binary=False)

        # cache metadata
        metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
        fn_cache_manager.put(json.dumps(metadata), data_name, binary=False)

    if warm_cache_only:
|
||||
launch_exit_hook = None
|
||||
|
||||
def __init__(self, fn_name, so_path, cache_dir, device):
|
||||
print("compiler.py: CompiledKernel:__init__")
|
||||
print(f"\t{self, fn_name, so_path, cache_dir, device}")
|
||||
# initialize launcher
|
||||
import importlib.util
|
||||
spec = importlib.util.spec_from_file_location("launcher", so_path)
|
||||
print("spec:", spec)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
self.c_wrapper = getattr(mod, "launch")
|
||||
@@ -1289,16 +1382,25 @@ class CompiledKernel:
|
||||
self.num_stages = metadata["num_stages"]
|
||||
# initialize asm dict
|
||||
self.asm = dict()
|
||||
with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
|
||||
self.asm["cubin"] = f.read()
|
||||
with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
|
||||
self.asm["ptx"] = f.read()
|
||||
if torch.version.hip is not None:
|
||||
with open(os.path.join(cache_dir, f"{fn_name}.hsaco"), "rb") as f:
|
||||
self.asm["hsaco_path"] = f.read()
|
||||
with open(os.path.join(cache_dir, f"{fn_name}.gcn"), "r") as f:
|
||||
self.asm["amdgcn"] = f.read()
|
||||
else:
|
||||
with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
|
||||
self.asm["cubin"] = f.read()
|
||||
with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
|
||||
self.asm["ptx"] = f.read()
|
||||
with open(os.path.join(cache_dir, f"{fn_name}.llir"), "r") as f:
|
||||
self.asm["llir"] = f.read()
|
||||
with open(os.path.join(cache_dir, f"{fn_name}.ttir"), "r") as f:
|
||||
self.asm["ttir"] = f.read()
|
||||
|
||||
mod, func, n_regs, n_spills = _triton.code_gen.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
|
||||
if torch.version.hip is not None:
|
||||
mod, func, n_regs, n_spills = _triton.code_gen.load_binary_hsaco(metadata["name"], self.asm["hsaco_path"], self.shared, device)
|
||||
else:
|
||||
mod, func, n_regs, n_spills = _triton.code_gen.load_binary_cubin(metadata["name"], self.asm["cubin"], self.shared, device)
|
||||
self.fn_name = fn_name
|
||||
self.cu_module = mod
|
||||
self.cu_function = func
|
||||
|
@@ -6,6 +6,8 @@ from . import core as tl
from triton._C.libtriton.triton import ir

import torch

# Create custom exception that prints message "hello"
class IncompatibleTypeErrorimpl(Exception):
    def __init__(self, type_a, type_b):
@@ -969,6 +971,11 @@ def dot(a: tl.tensor,
        trans_b: bool,
        allow_tf32: bool,
        builder: ir.builder) -> tl.tensor:

    if torch.version.hip is not None:
        a = cast(a, tl.float32, builder)
        b = cast(b, tl.float32, builder)

    in_a = 1 if not trans_a else 0
    in_b = 1 if trans_b else 0
    assert a.type.is_block() and b.type.is_block()
14
triton_rocm_20-52.Dockerfile
Normal file
@@ -0,0 +1,14 @@
FROM rocm/pytorch:rocm5.2.3_ubuntu20.04_py3.7_pytorch_1.12.1

# build triton
ENV TRITON_USE_ROCM=ON MI_GPU_ARCH=gfx90a

# Unit Tests
# to run unit tests
# 1. build this Dockerfile
#    docker build -f triton_rocm_20-52.Dockerfile -t triton_rocm52 .
# 2. run the docker container
#    docker run -it --rm --network=host --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --name triton --ipc=host --shm-size 16G --device=/dev/kfd --device=/dev/dri triton_rocm52:latest
# 3. run the core unit tests on a rocm machine
#    cd ~/triton/python
#    pytest --verbose test/unit/language/test_core.py | tee test_core.log