From 406d03bfaf5c2824a1743ed38ae4e33f19d5b98b Mon Sep 17 00:00:00 2001 From: Daniil Fukalov <1671137+dfukalov@users.noreply.github.com> Date: Fri, 14 Oct 2022 21:33:42 +0300 Subject: [PATCH] Improve ROCm support. (#780) - updates to support ROCm 5.2 - workarounds in tests where NV tools were used unconditionally - implemented `get_num_blocks()` and `add_memfence()` for AMD GPU - backported from history some atomics - added bf16 support - minor warnings cleanup - added dockerfile to run on a ROCm enabled machine Co-authored-by: B1tway Co-authored-by: Andrey Shukshov <36711069+B1tway@users.noreply.github.com> --- CMakeLists.txt | 5 + include/triton/codegen/analysis/layout.h | 5 +- include/triton/driver/dispatch.h | 5 +- include/triton/driver/llvm.h | 4 + include/triton/external/hip.h | 319 +++++++++++++-------- include/triton/ir/type.h | 1 + lib/codegen/selection/generator.cc | 152 +++++++++- lib/codegen/target.cc | 48 +++- lib/driver/dispatch.cc | 4 +- lib/driver/llvm.cc | 2 +- lib/ir/dispatch.cc | 2 + python/src/triton.cc | 14 +- python/test/regression/test_performance.py | 4 +- python/test/unit/language/test_core.py | 3 - python/test/unit/operators/test_matmul.py | 2 +- python/triton/testing.py | 6 +- triton_rocm_20-52.Dockerfile | 14 + 17 files changed, 435 insertions(+), 155 deletions(-) create mode 100644 triton_rocm_20-52.Dockerfile diff --git a/CMakeLists.txt b/CMakeLists.txt index 94808a586..b19e2ed58 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -41,6 +41,11 @@ endif() if (TRITON_USE_ROCM) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17 -Wno-unused-result -Wno-attributes") + set(MI_GPU_ARCH $ENV{MI_GPU_ARCH}) + if (NOT MI_GPU_ARCH) + set(MI_GPU_ARCH "gfx90a") + endif() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMI_GPU_ARCH=${MI_GPU_ARCH}") else() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17") endif() diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h index 4d12e34c0..ec13cbb6c 100644 --- a/include/triton/codegen/analysis/layout.h +++ b/include/triton/codegen/analysis/layout.h @@ -137,7 +137,8 @@ private: std::vector rep_; }; -struct scanline_layout: public distributed_layout { +class scanline_layout: public distributed_layout { +public: scanline_layout(size_t num_warps, const std::vector& axes, const std::vector& shape, @@ -149,7 +150,7 @@ struct scanline_layout: public distributed_layout { int mts(size_t k) { return mts_.at(k); } int nts(size_t k) { return nts_.at(k); } -public: +private: // micro tile size. The size of a tile held by a thread block. std::vector mts_; // nano tile size. The size of a tile held by a thread. 
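Note: the MI_GPU_ARCH define introduced above is consumed through the STRINGIFY helpers this patch adds to include/triton/driver/llvm.h and then uses in lib/driver/llvm.cc and python/src/triton.cc. A minimal standalone sketch of that expansion (illustrative only, not part of the patch):

// Illustrative only: how the CMake define -DMI_GPU_ARCH=<arch> becomes the
// processor string handed to the AMDGPU backend.
#include <iostream>
#include <string>

#ifndef MI_GPU_ARCH
#define MI_GPU_ARCH gfx90a   // default chosen by CMakeLists.txt when $ENV{MI_GPU_ARCH} is unset
#endif

#define STRINGIFY_HELPER(X) #X
#define STRINGIFY(X) STRINGIFY_HELPER(X)

int main() {
  std::string proc = STRINGIFY(MI_GPU_ARCH);  // "gfx90a", or e.g. "gfx908" if overridden at build time
  std::cout << "targeting " << proc << "\n";
  return 0;
}

Building with MI_GPU_ARCH=gfx908 set in the environment would therefore make llir_to_amdgpu target gfx908 instead of the gfx90a default.
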
diff --git a/include/triton/driver/dispatch.h b/include/triton/driver/dispatch.h index ccecf604a..9535d0cc8 100755 --- a/include/triton/driver/dispatch.h +++ b/include/triton/driver/dispatch.h @@ -181,7 +181,8 @@ public: static hipError_t hipEventElapsedTime(float *pMilliseconds, hipEvent_t hStart, hipEvent_t hEnd); static hipError_t hipEventRecord(hipEvent_t hEvent, hipStream_t hStream); static hipError_t hipEventDestroy(hipEvent_t hEvent); - + // error handling + static hipError_t hipGetLastError(void); private: @@ -306,6 +307,8 @@ private: static void* hipEventElapsedTime_; static void* hipEventRecord_; static void* hipEventDestroy_; + // error handling + static void* hipGetLastError_; }; } diff --git a/include/triton/driver/llvm.h b/include/triton/driver/llvm.h index 89dc98169..89cf339a3 100644 --- a/include/triton/driver/llvm.h +++ b/include/triton/driver/llvm.h @@ -17,3 +17,7 @@ hipModule_t amdgpu_to_hipmodule(const std::string& path); } } + +#define STRINGIFY_HELPER(X) #X +#define STRINGIFY(X) STRINGIFY_HELPER(X) + diff --git a/include/triton/external/hip.h b/include/triton/external/hip.h index a099f857b..f88463749 100644 --- a/include/triton/external/hip.h +++ b/include/triton/external/hip.h @@ -1,13 +1,35 @@ /* - * @brief hipError_t - * @enum - * @ingroup Enumerations - */ -// Developer note - when updating these, update the hipErrorName and hipErrorString functions in -// NVCC and HCC paths Also update the hipCUDAErrorTohipError function in NVCC path. +Copyright (c) 2015 - 2022 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#ifndef HIP_H +#define HIP_H // Ignoring error-code return values from hip APIs is discouraged. On C++17, // we can make that yield a warning +#if __cplusplus >= 201703L +#define __HIP_NODISCARD [[nodiscard]] +#else +#define __HIP_NODISCARD +#endif /* * @brief hipError_t @@ -17,9 +39,7 @@ // Developer note - when updating these, update the hipErrorName and hipErrorString functions in // NVCC and HCC paths Also update the hipCUDAErrorTohipError function in NVCC path. -#include - -typedef enum hipError_t { +typedef enum __HIP_NODISCARD hipError_t { hipSuccess = 0, ///< Successful completion. hipErrorInvalidValue = 1, ///< One or more of the parameters passed to the API call is NULL ///< or not in an acceptable range. @@ -73,6 +93,7 @@ typedef enum hipError_t { hipErrorInvalidHandle = 400, // Deprecated hipErrorInvalidResourceHandle = 400, ///< Resource handle (hipEvent_t or hipStream_t) invalid. 
+ hipErrorIllegalState = 401, ///< Resource required is not in a valid state to perform operation. hipErrorNotFound = 500, hipErrorNotReady = 600, ///< Indicates that asynchronous operations enqueued earlier are not ///< ready. This is not actually an error, but is used to distinguish @@ -86,6 +107,7 @@ typedef enum hipError_t { hipErrorPeerAccessNotEnabled = 705, ///< Peer access was never enabled from the current device. hipErrorSetOnActiveProcess = 708, + hipErrorContextIsDestroyed = 709, hipErrorAssert = 710, ///< Produced when the kernel calls assert. hipErrorHostMemoryAlreadyRegistered = 712, ///< Produced when trying to lock a page-locked memory. @@ -98,6 +120,32 @@ typedef enum hipError_t { ///< that was launched via cooperative launch APIs exceeds the maximum number of ///< allowed blocks for the current device hipErrorNotSupported = 801, ///< Produced when the hip API is not supported/implemented + hipErrorStreamCaptureUnsupported = 900, ///< The operation is not permitted when the stream + ///< is capturing. + hipErrorStreamCaptureInvalidated = 901, ///< The current capture sequence on the stream + ///< has been invalidated due to a previous error. + hipErrorStreamCaptureMerge = 902, ///< The operation would have resulted in a merge of + ///< two independent capture sequences. + hipErrorStreamCaptureUnmatched = 903, ///< The capture was not initiated in this stream. + hipErrorStreamCaptureUnjoined = 904, ///< The capture sequence contains a fork that was not + ///< joined to the primary stream. + hipErrorStreamCaptureIsolation = 905, ///< A dependency would have been created which crosses + ///< the capture sequence boundary. Only implicit + ///< in-stream ordering dependencies are allowed + ///< to cross the boundary + hipErrorStreamCaptureImplicit = 906, ///< The operation would have resulted in a disallowed + ///< implicit dependency on a current capture sequence + ///< from hipStreamLegacy. + hipErrorCapturedEvent = 907, ///< The operation is not permitted on an event which was last + ///< recorded in a capturing stream. + hipErrorStreamCaptureWrongThread = 908, ///< A stream capture sequence not initiated with + ///< the hipStreamCaptureModeRelaxed argument to + ///< hipStreamBeginCapture was passed to + ///< hipStreamEndCapture in a different thread. + hipErrorGraphExecUpdateFailure = 910, ///< This error indicates that the graph update + ///< not performed because it included changes which + ///< violated constraintsspecific to instantiated graph + ///< update. hipErrorUnknown = 999, //< Unknown error. // HSA Runtime Error Codes start here. hipErrorRuntimeMemory = 1052, ///< HSA runtime memory call returned error. Typically not seen @@ -107,35 +155,154 @@ typedef enum hipError_t { hipErrorTbd ///< Marker that more error codes are needed. } hipError_t; +#undef __HIP_NODISCARD +/* + * @brief hipDeviceAttribute_t + * @enum + * @ingroup Enumerations + */ +typedef enum hipDeviceAttribute_t { + hipDeviceAttributeCudaCompatibleBegin = 0, + + hipDeviceAttributeEccEnabled = hipDeviceAttributeCudaCompatibleBegin, ///< Whether ECC support is enabled. + hipDeviceAttributeAccessPolicyMaxWindowSize, ///< Cuda only. The maximum size of the window policy in bytes. + hipDeviceAttributeAsyncEngineCount, ///< Cuda only. Asynchronous engines number. + hipDeviceAttributeCanMapHostMemory, ///< Whether host memory can be mapped into device address space + hipDeviceAttributeCanUseHostPointerForRegisteredMem,///< Cuda only. 
Device can access host registered memory + ///< at the same virtual address as the CPU + hipDeviceAttributeClockRate, ///< Peak clock frequency in kilohertz. + hipDeviceAttributeComputeMode, ///< Compute mode that device is currently in. + hipDeviceAttributeComputePreemptionSupported, ///< Cuda only. Device supports Compute Preemption. + hipDeviceAttributeConcurrentKernels, ///< Device can possibly execute multiple kernels concurrently. + hipDeviceAttributeConcurrentManagedAccess, ///< Device can coherently access managed memory concurrently with the CPU + hipDeviceAttributeCooperativeLaunch, ///< Support cooperative launch + hipDeviceAttributeCooperativeMultiDeviceLaunch, ///< Support cooperative launch on multiple devices + hipDeviceAttributeDeviceOverlap, ///< Cuda only. Device can concurrently copy memory and execute a kernel. + ///< Deprecated. Use instead asyncEngineCount. + hipDeviceAttributeDirectManagedMemAccessFromHost, ///< Host can directly access managed memory on + ///< the device without migration + hipDeviceAttributeGlobalL1CacheSupported, ///< Cuda only. Device supports caching globals in L1 + hipDeviceAttributeHostNativeAtomicSupported, ///< Cuda only. Link between the device and the host supports native atomic operations + hipDeviceAttributeIntegrated, ///< Device is integrated GPU + hipDeviceAttributeIsMultiGpuBoard, ///< Multiple GPU devices. + hipDeviceAttributeKernelExecTimeout, ///< Run time limit for kernels executed on the device + hipDeviceAttributeL2CacheSize, ///< Size of L2 cache in bytes. 0 if the device doesn't have L2 cache. + hipDeviceAttributeLocalL1CacheSupported, ///< caching locals in L1 is supported + hipDeviceAttributeLuid, ///< Cuda only. 8-byte locally unique identifier in 8 bytes. Undefined on TCC and non-Windows platforms + hipDeviceAttributeLuidDeviceNodeMask, ///< Cuda only. Luid device node mask. Undefined on TCC and non-Windows platforms + hipDeviceAttributeComputeCapabilityMajor, ///< Major compute capability version number. + hipDeviceAttributeManagedMemory, ///< Device supports allocating managed memory on this system + hipDeviceAttributeMaxBlocksPerMultiProcessor, ///< Cuda only. Max block size per multiprocessor + hipDeviceAttributeMaxBlockDimX, ///< Max block size in width. + hipDeviceAttributeMaxBlockDimY, ///< Max block size in height. + hipDeviceAttributeMaxBlockDimZ, ///< Max block size in depth. + hipDeviceAttributeMaxGridDimX, ///< Max grid size in width. + hipDeviceAttributeMaxGridDimY, ///< Max grid size in height. + hipDeviceAttributeMaxGridDimZ, ///< Max grid size in depth. + hipDeviceAttributeMaxSurface1D, ///< Maximum size of 1D surface. + hipDeviceAttributeMaxSurface1DLayered, ///< Cuda only. Maximum dimensions of 1D layered surface. + hipDeviceAttributeMaxSurface2D, ///< Maximum dimension (width, height) of 2D surface. + hipDeviceAttributeMaxSurface2DLayered, ///< Cuda only. Maximum dimensions of 2D layered surface. + hipDeviceAttributeMaxSurface3D, ///< Maximum dimension (width, height, depth) of 3D surface. + hipDeviceAttributeMaxSurfaceCubemap, ///< Cuda only. Maximum dimensions of Cubemap surface. + hipDeviceAttributeMaxSurfaceCubemapLayered, ///< Cuda only. Maximum dimension of Cubemap layered surface. + hipDeviceAttributeMaxTexture1DWidth, ///< Maximum size of 1D texture. + hipDeviceAttributeMaxTexture1DLayered, ///< Cuda only. Maximum dimensions of 1D layered texture. + hipDeviceAttributeMaxTexture1DLinear, ///< Maximum number of elements allocatable in a 1D linear texture. 
+ ///< Use cudaDeviceGetTexture1DLinearMaxWidth() instead on Cuda. + hipDeviceAttributeMaxTexture1DMipmap, ///< Cuda only. Maximum size of 1D mipmapped texture. + hipDeviceAttributeMaxTexture2DWidth, ///< Maximum dimension width of 2D texture. + hipDeviceAttributeMaxTexture2DHeight, ///< Maximum dimension hight of 2D texture. + hipDeviceAttributeMaxTexture2DGather, ///< Cuda only. Maximum dimensions of 2D texture if gather operations performed. + hipDeviceAttributeMaxTexture2DLayered, ///< Cuda only. Maximum dimensions of 2D layered texture. + hipDeviceAttributeMaxTexture2DLinear, ///< Cuda only. Maximum dimensions (width, height, pitch) of 2D textures bound to pitched memory. + hipDeviceAttributeMaxTexture2DMipmap, ///< Cuda only. Maximum dimensions of 2D mipmapped texture. + hipDeviceAttributeMaxTexture3DWidth, ///< Maximum dimension width of 3D texture. + hipDeviceAttributeMaxTexture3DHeight, ///< Maximum dimension height of 3D texture. + hipDeviceAttributeMaxTexture3DDepth, ///< Maximum dimension depth of 3D texture. + hipDeviceAttributeMaxTexture3DAlt, ///< Cuda only. Maximum dimensions of alternate 3D texture. + hipDeviceAttributeMaxTextureCubemap, ///< Cuda only. Maximum dimensions of Cubemap texture + hipDeviceAttributeMaxTextureCubemapLayered, ///< Cuda only. Maximum dimensions of Cubemap layered texture. + hipDeviceAttributeMaxThreadsDim, ///< Maximum dimension of a block + hipDeviceAttributeMaxThreadsPerBlock, ///< Maximum number of threads per block. + hipDeviceAttributeMaxThreadsPerMultiProcessor, ///< Maximum resident threads per multiprocessor. + hipDeviceAttributeMaxPitch, ///< Maximum pitch in bytes allowed by memory copies + hipDeviceAttributeMemoryBusWidth, ///< Global memory bus width in bits. + hipDeviceAttributeMemoryClockRate, ///< Peak memory clock frequency in kilohertz. + hipDeviceAttributeComputeCapabilityMinor, ///< Minor compute capability version number. + hipDeviceAttributeMultiGpuBoardGroupID, ///< Cuda only. Unique ID of device group on the same multi-GPU board + hipDeviceAttributeMultiprocessorCount, ///< Number of multiprocessors on the device. + hipDeviceAttributeName, ///< Device name. + hipDeviceAttributePageableMemoryAccess, ///< Device supports coherently accessing pageable memory + ///< without calling hipHostRegister on it + hipDeviceAttributePageableMemoryAccessUsesHostPageTables, ///< Device accesses pageable memory via the host's page tables + hipDeviceAttributePciBusId, ///< PCI Bus ID. + hipDeviceAttributePciDeviceId, ///< PCI Device ID. + hipDeviceAttributePciDomainID, ///< PCI Domain ID. + hipDeviceAttributePersistingL2CacheMaxSize, ///< Cuda11 only. Maximum l2 persisting lines capacity in bytes + hipDeviceAttributeMaxRegistersPerBlock, ///< 32-bit registers available to a thread block. This number is shared + ///< by all thread blocks simultaneously resident on a multiprocessor. + hipDeviceAttributeMaxRegistersPerMultiprocessor, ///< 32-bit registers available per block. + hipDeviceAttributeReservedSharedMemPerBlock, ///< Cuda11 only. Shared memory reserved by CUDA driver per block. + hipDeviceAttributeMaxSharedMemoryPerBlock, ///< Maximum shared memory available per block in bytes. + hipDeviceAttributeSharedMemPerBlockOptin, ///< Cuda only. Maximum shared memory per block usable by special opt in. + hipDeviceAttributeSharedMemPerMultiprocessor, ///< Cuda only. Shared memory available per multiprocessor. + hipDeviceAttributeSingleToDoublePrecisionPerfRatio, ///< Cuda only. Performance ratio of single precision to double precision. 
+ hipDeviceAttributeStreamPrioritiesSupported, ///< Cuda only. Whether to support stream priorities. + hipDeviceAttributeSurfaceAlignment, ///< Cuda only. Alignment requirement for surfaces + hipDeviceAttributeTccDriver, ///< Cuda only. Whether device is a Tesla device using TCC driver + hipDeviceAttributeTextureAlignment, ///< Alignment requirement for textures + hipDeviceAttributeTexturePitchAlignment, ///< Pitch alignment requirement for 2D texture references bound to pitched memory; + hipDeviceAttributeTotalConstantMemory, ///< Constant memory size in bytes. + hipDeviceAttributeTotalGlobalMem, ///< Global memory available on devicice. + hipDeviceAttributeUnifiedAddressing, ///< Cuda only. An unified address space shared with the host. + hipDeviceAttributeUuid, ///< Cuda only. Unique ID in 16 byte. + hipDeviceAttributeWarpSize, ///< Warp size in threads. + hipDeviceAttributeMemoryPoolsSupported, ///< Device supports HIP Stream Ordered Memory Allocator + + hipDeviceAttributeCudaCompatibleEnd = 9999, + hipDeviceAttributeAmdSpecificBegin = 10000, + + hipDeviceAttributeClockInstructionRate = hipDeviceAttributeAmdSpecificBegin, ///< Frequency in khz of the timer used by the device-side "clock*" + hipDeviceAttributeArch, ///< Device architecture + hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, ///< Maximum Shared Memory PerMultiprocessor. + hipDeviceAttributeGcnArch, ///< Device gcn architecture + hipDeviceAttributeGcnArchName, ///< Device gcnArch name in 256 bytes + hipDeviceAttributeHdpMemFlushCntl, ///< Address of the HDP_MEM_COHERENCY_FLUSH_CNTL register + hipDeviceAttributeHdpRegFlushCntl, ///< Address of the HDP_REG_COHERENCY_FLUSH_CNTL register + hipDeviceAttributeCooperativeMultiDeviceUnmatchedFunc, ///< Supports cooperative launch on multiple + ///< devices with unmatched functions + hipDeviceAttributeCooperativeMultiDeviceUnmatchedGridDim, ///< Supports cooperative launch on multiple + ///< devices with unmatched grid dimensions + hipDeviceAttributeCooperativeMultiDeviceUnmatchedBlockDim, ///< Supports cooperative launch on multiple + ///< devices with unmatched block dimensions + hipDeviceAttributeCooperativeMultiDeviceUnmatchedSharedMem, ///< Supports cooperative launch on multiple + ///< devices with unmatched shared memories + hipDeviceAttributeIsLargeBar, ///< Whether it is LargeBar + hipDeviceAttributeAsicRevision, ///< Revision of the GPU in this device + hipDeviceAttributeCanUseStreamWaitValue, ///< '1' if Device supports hipStreamWaitValue32() and + ///< hipStreamWaitValue64(), '0' otherwise. + hipDeviceAttributeImageSupport, ///< '1' if Device supports image, '0' otherwise. 
+ hipDeviceAttributePhysicalMultiProcessorCount, ///< All available physical compute + ///< units for the device + hipDeviceAttributeFineGrainSupport, ///< '1' if Device supports fine grain, '0' otherwise + + hipDeviceAttributeAmdSpecificEnd = 19999, + hipDeviceAttributeVendorSpecificBegin = 20000, + // Extended attributes for vendors +} hipDeviceAttribute_t; + +#define HIP_LAUNCH_PARAM_BUFFER_POINTER ((void*)0x01) +#define HIP_LAUNCH_PARAM_BUFFER_SIZE ((void*)0x02) +#define HIP_LAUNCH_PARAM_END ((void*)0x03) +// API-visible structures typedef struct ihipCtx_t* hipCtx_t; - // Note many APIs also use integer deviceIds as an alternative to the device pointer: typedef int hipDevice_t; - -typedef enum hipDeviceP2PAttr { - hipDevP2PAttrPerformanceRank = 0, - hipDevP2PAttrAccessSupported, - hipDevP2PAttrNativeAtomicSupported, - hipDevP2PAttrHipArrayAccessSupported -} hipDeviceP2PAttr; - typedef struct ihipStream_t* hipStream_t; - -#define hipIpcMemLazyEnablePeerAccess 0 - #define HIP_IPC_HANDLE_SIZE 64 - -typedef struct hipIpcMemHandle_st { - char reserved[HIP_IPC_HANDLE_SIZE]; -} hipIpcMemHandle_t; - -typedef struct hipIpcEventHandle_st { - char reserved[HIP_IPC_HANDLE_SIZE]; -} hipIpcEventHandle_t; - typedef struct ihipModule_t* hipModule_t; - typedef struct ihipModuleSymbol_t* hipFunction_t; typedef struct hipFuncAttributes { @@ -150,91 +317,8 @@ typedef struct hipFuncAttributes { int ptxVersion; size_t sharedSizeBytes; } hipFuncAttributes; - typedef struct ihipEvent_t* hipEvent_t; -/* - * @brief hipDeviceAttribute_t - * @enum - * @ingroup Enumerations - */ -typedef enum hipDeviceAttribute_t { - hipDeviceAttributeMaxThreadsPerBlock, ///< Maximum number of threads per block. - hipDeviceAttributeMaxBlockDimX, ///< Maximum x-dimension of a block. - hipDeviceAttributeMaxBlockDimY, ///< Maximum y-dimension of a block. - hipDeviceAttributeMaxBlockDimZ, ///< Maximum z-dimension of a block. - hipDeviceAttributeMaxGridDimX, ///< Maximum x-dimension of a grid. - hipDeviceAttributeMaxGridDimY, ///< Maximum y-dimension of a grid. - hipDeviceAttributeMaxGridDimZ, ///< Maximum z-dimension of a grid. - hipDeviceAttributeMaxSharedMemoryPerBlock, ///< Maximum shared memory available per block in - ///< bytes. - hipDeviceAttributeTotalConstantMemory, ///< Constant memory size in bytes. - hipDeviceAttributeWarpSize, ///< Warp size in threads. - hipDeviceAttributeMaxRegistersPerBlock, ///< Maximum number of 32-bit registers available to a - ///< thread block. This number is shared by all thread - ///< blocks simultaneously resident on a - ///< multiprocessor. - hipDeviceAttributeClockRate, ///< Peak clock frequency in kilohertz. - hipDeviceAttributeMemoryClockRate, ///< Peak memory clock frequency in kilohertz. - hipDeviceAttributeMemoryBusWidth, ///< Global memory bus width in bits. - hipDeviceAttributeMultiprocessorCount, ///< Number of multiprocessors on the device. - hipDeviceAttributeComputeMode, ///< Compute mode that device is currently in. - hipDeviceAttributeL2CacheSize, ///< Size of L2 cache in bytes. 0 if the device doesn't have L2 - ///< cache. - hipDeviceAttributeMaxThreadsPerMultiProcessor, ///< Maximum resident threads per - ///< multiprocessor. - hipDeviceAttributeComputeCapabilityMajor, ///< Major compute capability version number. - hipDeviceAttributeComputeCapabilityMinor, ///< Minor compute capability version number. - hipDeviceAttributeConcurrentKernels, ///< Device can possibly execute multiple kernels - ///< concurrently. - hipDeviceAttributePciBusId, ///< PCI Bus ID. 
- hipDeviceAttributePciDeviceId, ///< PCI Device ID. - hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, ///< Maximum Shared Memory Per - ///< Multiprocessor. - hipDeviceAttributeIsMultiGpuBoard, ///< Multiple GPU devices. - hipDeviceAttributeIntegrated, ///< iGPU - hipDeviceAttributeCooperativeLaunch, ///< Support cooperative launch - hipDeviceAttributeCooperativeMultiDeviceLaunch, ///< Support cooperative launch on multiple devices - hipDeviceAttributeMaxTexture1DWidth, ///< Maximum number of elements in 1D images - hipDeviceAttributeMaxTexture2DWidth, ///< Maximum dimension width of 2D images in image elements - hipDeviceAttributeMaxTexture2DHeight, ///< Maximum dimension height of 2D images in image elements - hipDeviceAttributeMaxTexture3DWidth, ///< Maximum dimension width of 3D images in image elements - hipDeviceAttributeMaxTexture3DHeight, ///< Maximum dimensions height of 3D images in image elements - hipDeviceAttributeMaxTexture3DDepth, ///< Maximum dimensions depth of 3D images in image elements - - hipDeviceAttributeHdpMemFlushCntl, ///< Address of the HDP_MEM_COHERENCY_FLUSH_CNTL register - hipDeviceAttributeHdpRegFlushCntl, ///< Address of the HDP_REG_COHERENCY_FLUSH_CNTL register - - hipDeviceAttributeMaxPitch, ///< Maximum pitch in bytes allowed by memory copies - hipDeviceAttributeTextureAlignment, ///(ptr)) if(ConstantInt* cst1 = dyn_cast(gep->idx_begin())) if(ConstantInt* cst2 = dyn_cast(off)){ - return (*builder_)->CreateGEP(gep->getPointerOperand(), - (*builder_)->CreateAdd(cst1, cst2)); + return (*builder_)->CreateGEP(gep->getPointerOperand()->getType()->getScalarType()->getPointerElementType(), + gep->getPointerOperand(), (*builder_)->CreateAdd(cst1, cst2)); } // ptr + (off + cst) -> (ptr + off) + cst if(auto* bin = dyn_cast(off)) if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add) if(ConstantInt* cst = dyn_cast(bin->getOperand(1))){ - return (*builder_)->CreateGEP((*builder_)->CreateGEP(ptr, bin->getOperand(0)), - bin->getOperand(1)); + Value *gep = (*builder_)->CreateGEP(ptr->getType()->getScalarType()->getPointerElementType(), + ptr, bin->getOperand(0)); + return (*builder_)->CreateGEP(gep->getType()->getScalarType()->getPointerElementType(), + gep, bin->getOperand(1)); } // default - return (*builder_)->CreateGEP(ptr, off, name); + return (*builder_)->CreateGEP(ptr->getType()->getScalarType()->getPointerElementType(), + ptr, off, name); } //Value* geper::operator()(Type *ty, Value *ptr, std::vector vals, const std::string &name) { @@ -91,6 +94,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){ // types #define void_ty builder_->getVoidTy() #define f16_ty builder_->getHalfTy() +#define bf16_ty builder_->getInt16Ty() #define f32_ty builder_->getFloatTy() #define f64_ty builder_->getDoubleTy() #define i8_ty builder_->getInt8Ty() @@ -124,7 +128,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){ #define icmp_ult(...) builder_->CreateICmpULT(__VA_ARGS__) #define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__) #define intrinsic(...) builder_->CreateIntrinsic(__VA_ARGS__) -#define load(...) builder_->CreateLoad(__VA_ARGS__) +#define load(ptr) builder_->CreateLoad(ptr->getType()->getPointerElementType(), ptr) #define lshr(...) builder_->CreateLShr(__VA_ARGS__) #define max_num(...) builder_->CreateMaxNum(__VA_ARGS__) #define min_num(...) 
builder_->CreateMinNum(__VA_ARGS__) @@ -293,8 +297,8 @@ void generator::visit_phi_node(ir::phi_node* x) { */ void generator::visit_binary_operator(ir::binary_operator*x) { using ll = llvm::Instruction::BinaryOps; + using tt = ir::binary_op_t; auto cvt = [](ir::binary_op_t op){ - using tt = ir::binary_op_t; switch(op) { case tt::Add: return ll::Add; case tt::FAdd: return ll::FAdd; @@ -320,13 +324,47 @@ void generator::visit_binary_operator(ir::binary_operator*x) { for(indices_t idx: idxs_.at(x)){ Value *lhs = vals_[x->get_operand(0)][idx]; Value *rhs = vals_[x->get_operand(1)][idx]; - auto op = cvt(x->get_op()); - if(op == ll::Add) - vals_[x][idx] = add(lhs, rhs); - else if(op == ll::Mul) - vals_[x][idx] = mul(lhs, rhs); - else - vals_[x][idx] = bin_op(op, lhs, rhs); + if (x->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty()) { + assert(x->get_operand(1)->get_type()->get_scalar_ty()->is_bf16_ty()); + if (x->get_op() == tt::FAdd) { + InlineAsm *bf16_add_asm = + InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false), + "v_lshlrev_b32 $1, 16, $1 " + "v_lshlrev_b32 $2, 16, $2 " + "v_add_f32 $0, $1, $2 " + "v_lshrrev_b32 $0, 16, $0 ", + "=v,v,v", false); + vals_[x][idx] = builder_->CreateCall(bf16_add_asm, {lhs, rhs}); + } else if (x->get_op() == tt::FSub) { // a - b = b * (-1.0) + a + InlineAsm *bf16_sub_asm = + InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false), + "v_lshlrev_b32 $1, 16, $1 " + "v_lshlrev_b32 $2, 16, $2 " + "v_sub_f32 $0, $1, $2 " + "v_lshrrev_b32 $0, 16, $0 ", + "=v,v,v", false); + vals_[x][idx] = builder_->CreateCall(bf16_sub_asm, {lhs, rhs}); + } else if (x->get_op() == tt::FMul) { // a * b = a*b + 0 + InlineAsm *bf16_mul_asm = + InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false), + "v_lshlrev_b32 $1, 16, $1 " + "v_lshlrev_b32 $2, 16, $2 " + "v_mul_f32 $0, $1, $2 " + "v_lshrrev_b32 $0, 16, $0 ", + "=v,v,v", false); + vals_[x][idx] = builder_->CreateCall(bf16_mul_asm, {lhs, rhs}); + } else + throw std::runtime_error("invalid bin op for bf16"); + } + else { + auto op = cvt(x->get_op()); + if(op == ll::Add) + vals_[x][idx] = add(lhs, rhs); + else if(op == ll::Mul) + vals_[x][idx] = mul(lhs, rhs); + else + vals_[x][idx] = bin_op(op, lhs, rhs); + } } } @@ -979,6 +1017,35 @@ void generator::visit_log_inst(ir::log_inst* x){ /** * \brief Code Generation for `atomic_cas` */ +#if defined(USE_ROCM) +void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) { + BasicBlock *current = builder_->GetInsertBlock(); + Module *module = current->getModule(); + Value *tid = tgt_->get_local_id(module, *builder_, 0); + Value *pred = icmp_eq(tid, i32(0)); + BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent()); + BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent()); + add_barrier(); + tgt_->add_memfence(module, *builder_); + cond_br(pred, tid_0_bb, tid_0_done_bb); + builder_->SetInsertPoint(tid_0_bb); + Value *cas_ptr = vals_[cas->get_operand(0)][{}]; + Value *cas_cmp = vals_[cas->get_operand(1)][{}]; + Value *cas_val = vals_[cas->get_operand(2)][{}]; + Value *old = atomic_cmp_xchg(cas_ptr, cas_cmp, cas_val, MaybeAlign(), AtomicOrdering::Monotonic, AtomicOrdering::Monotonic); + old = extract_val(old, std::vector{0}); + Value *atom_ptr; + atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(cas)))), ""); + atom_ptr = bit_cast(atom_ptr, ptr_ty(old->getType(), 3)); + store(old, atom_ptr); + br(tid_0_done_bb); + builder_->SetInsertPoint(tid_0_done_bb); + 
tgt_->add_memfence(module, *builder_); + add_barrier(); + vals_[cas][{}] = load(atom_ptr); + add_barrier(); +} +#else void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) { BasicBlock *current = builder_->GetInsertBlock(); Module *module = current->getModule(); @@ -1013,12 +1080,66 @@ void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) { vals_[cas][{}] = load(atom_ptr); add_barrier(); } +#endif // defined(USE_ROCM) /** * \brief Code Generation for `atomic_rmw` */ +#if defined(USE_ROCM) void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) { - ir::value* ptr = atom->get_operand(0); + if (atom->get_op() == ir::atomic_rmw_op_t::Add || + atom->get_op() == ir::atomic_rmw_op_t::Max || + atom->get_op() == ir::atomic_rmw_op_t::Min || + atom->get_op() == ir::atomic_rmw_op_t::UMax || + atom->get_op() == ir::atomic_rmw_op_t::UMin || + atom->get_op() == ir::atomic_rmw_op_t::Xchg) { + BasicBlock *current = builder_->GetInsertBlock(); + Module *module = current->getModule(); + Value *rmw_ptr = vals_[atom->get_operand(0)][{}]; + Value *rmw_val = vals_[atom->get_operand(1)][{}]; + Value *tid = tgt_->get_local_id(module, *builder_, 0); + Value *pred = icmp_eq(tid, i32(0)); + BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent()); + BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent()); + tgt_->add_memfence(module, *builder_); + add_barrier(); + cond_br(pred, tid_0_bb, tid_0_done_bb); + builder_->SetInsertPoint(tid_0_bb); + AtomicRMWInst::BinOp binop; + switch (atom->get_op()) { + case ir::atomic_rmw_op_t::Add: + binop = AtomicRMWInst::Add; + break; + case ir::atomic_rmw_op_t::Max: + binop = AtomicRMWInst::Max; + break; + case ir::atomic_rmw_op_t::Min: + binop = AtomicRMWInst::Min; + break; + case ir::atomic_rmw_op_t::UMax: + binop = AtomicRMWInst::UMax; + break; + case ir::atomic_rmw_op_t::UMin: + binop = AtomicRMWInst::UMin; + break; + case ir::atomic_rmw_op_t::Xchg: + binop = AtomicRMWInst::Xchg; + break; + default: + throw std::runtime_error("Not supported!"); + } + atomic_rmw(binop, rmw_ptr, rmw_val, MaybeAlign(), AtomicOrdering::Monotonic, + SyncScope::System); + br(tid_0_done_bb); + builder_->SetInsertPoint(tid_0_done_bb); + tgt_->add_memfence(module, *builder_); + return; + } + throw std::runtime_error("Not supported!"); +} +#else // defined(USE_ROCM) +void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) { + ir::value *ptr = atom->get_operand(0); ir::value* val = atom->get_operand(1); ir::value* msk = atom->get_operand(2); @@ -1100,6 +1221,7 @@ void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) { } } } +#endif // defined(USE_ROCM) /** * \brief Code Generation for `mma.884` (V100) diff --git a/lib/codegen/target.cc b/lib/codegen/target.cc index caa8d72f9..c775938ba 100644 --- a/lib/codegen/target.cc +++ b/lib/codegen/target.cc @@ -41,7 +41,8 @@ Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, un } Instruction* amd_cl_target::add_memfence(Module *module, IRBuilder<>& builder) { - throw std::runtime_error("not implemented on AMD"); + Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::amdgcn_s_waitcnt); + return builder.CreateIntrinsic(Intrinsic::amdgcn_s_waitcnt, {}, {builder.getInt32(0)}); } @@ -56,7 +57,50 @@ Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigne } Value* amd_cl_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) { - throw std::runtime_error("not implemented on AMD"); + 
Function &F = *builder.GetInsertBlock()->getParent(); + Module *Mod = F.getParent(); + // We are indexing into this struct, and want to extract the grid_size_* + // fields. + // + // typedef struct hsa_kernel_dispatch_packet_s { + // uint16_t header; + // uint16_t setup; + // uint16_t workgroup_size_x ; + // uint16_t workgroup_size_y; + // uint16_t workgroup_size_z; + // uint16_t reserved0; + // uint32_t grid_size_x ; + // uint32_t grid_size_y ; + // uint32_t grid_size_z; + // ..... + // } hsa_kernel_dispatch_packet_t + // + Function *DispatchPtrFn = + Intrinsic::getDeclaration(Mod, Intrinsic::amdgcn_dispatch_ptr); + + CallInst *DispatchPtr = builder.CreateCall(DispatchPtrFn, {}); + DispatchPtr->addAttribute(AttributeList::ReturnIndex, Attribute::NoAlias); + DispatchPtr->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull); + F.removeFnAttr("amdgpu-no-dispatch-ptr"); + + // Size of the dispatch packet struct. + DispatchPtr->addDereferenceableAttr(AttributeList::ReturnIndex, 64); + + Type *I32Ty = Type::getInt32Ty(Mod->getContext()); + // TODO: include AMDGPUAS:: declarations. + Value *CastDispatchPtr = builder.CreateBitCast( + DispatchPtr, PointerType::get(I32Ty, 4 /*AMDGPUAS::CONSTANT_ADDRESS*/)); + + // grid_size_x offset is 3*32bit + assert(ax < 3); + Value *GEP = + builder.CreateConstInBoundsGEP1_64(I32Ty, CastDispatchPtr, ax + 3); + LoadInst *Load = builder.CreateAlignedLoad(I32Ty, GEP, Align(4)); + + MDNode *MD = MDNode::get(Mod->getContext(), None); + Load->setMetadata(LLVMContext::MD_invariant_load, MD); + + return Load; // throw std::runtime_error("not implemented on AMD"); } Value* amd_cl_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) { diff --git a/lib/driver/dispatch.cc b/lib/driver/dispatch.cc index 4059ac235..ae43d2129 100755 --- a/lib/driver/dispatch.cc +++ b/lib/driver/dispatch.cc @@ -212,6 +212,7 @@ bool dispatch::hipinit(){ return res; } +#define HIP_DEFINE0(ret, fname) DEFINE0(hipinit, hip_, ret, fname) #define HIP_DEFINE1(ret, fname, t1) DEFINE1(hipinit, hip_, ret, fname, t1) #define HIP_DEFINE2(ret, fname, t1, t2) DEFINE2(hipinit, hip_, ret, fname, t1, t2) #define HIP_DEFINE3(ret, fname, t1, t2, t3) DEFINE3(hipinit, hip_, ret, fname, t1, t2, t3) @@ -268,7 +269,8 @@ HIP_DEFINE2(hipError_t, hipEventCreate, hipEvent_t *, unsigned int) HIP_DEFINE3(hipError_t, hipEventElapsedTime, float *, hipEvent_t, hipEvent_t) HIP_DEFINE2(hipError_t, hipEventRecord, hipEvent_t, hipStream_t) HIP_DEFINE1(hipError_t, hipEventDestroy, hipEvent_t) - +// error handling +HIP_DEFINE0(hipError_t, hipGetLastError) /* ------------------- * * COMMON diff --git a/lib/driver/llvm.cc b/lib/driver/llvm.cc index 3288bf9f1..c06e766e5 100644 --- a/lib/driver/llvm.cc +++ b/lib/driver/llvm.cc @@ -268,7 +268,7 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) { std::string triple = "amdgcn-amd-amdhsa"; std::string layout = ""; std::string features="+sramecc,-xnack"; - std::string proc = "gfx908"; + std::string proc = STRINGIFY(MI_GPU_ARCH); // name kernel auto in_time_t = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()); std::stringstream cur_time; diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index e1d51856e..8644cd552 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -47,6 +47,8 @@ ir::type *computation_type(ir::type* a_ty, ir::type* b_ty){ // converted to half if(a_ty->is_fp16_ty() || b_ty->is_fp16_ty()) return type::get_fp16_ty(ctx); + if(a_ty->is_bf16_ty() || b_ty->is_bf16_ty()) + return 
type::get_bf16_ty(ctx); if(!a_ty->is_integer_ty() || !b_ty->is_integer_ty()) throw_unreachable("augment_types"); // 4 ) both operands are integer and undergo diff --git a/python/src/triton.cc b/python/src/triton.cc index 7bb94d5fe..911f9c21a 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -97,7 +97,9 @@ void hip_enqueue(uint64_t stream, uint64_t kernel, drv::dispatch::hipModuleLaunchKernel((hipFunction_t)kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, shared_mem, (hipStream_t)stream, nullptr, config); - +#ifdef DEBUG_ROCM + drv::dispatch::hipGetLastError(); +#endif } void init_triton_runtime(py::module &&m) { @@ -249,7 +251,7 @@ std::tuple hip_compile_ttir(const std::string& name llir.flush(); asm_map["llir"] = py::cast(tmp); // LLVM-IR -> HSA-CO - std::string path = drv::llir_to_amdgpu(llvm.get(), "gfx908"); + std::string path = drv::llir_to_amdgpu(llvm.get(), STRINGIFY(MI_GPU_ARCH)); asm_map["hsaco"] = py::cast(path); return std::make_tuple(name, asm_map, n_shared_bytes); } @@ -266,14 +268,14 @@ void init_triton_codegen(py::module &&m) { llvm::LLVMContext ctx; if(backend == CUDA) return cu_compile_ttir(name, ir, device, num_warps, num_stages, asm_map); - if(backend == ROCM) - return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map); + assert(backend == ROCM); + return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map); }, py::return_value_policy::take_ownership); m.def("load_binary", [](backend_t backend, const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){ if(backend == CUDA) return cu_load_binary(name, asm_map, n_shared_bytes, dev); - if(backend == ROCM) - return hip_load_binary(name, asm_map, n_shared_bytes, dev); + assert(backend == ROCM); + return hip_load_binary(name, asm_map, n_shared_bytes, dev); }, py::return_value_policy::take_ownership); } diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py index e205828d6..eecc5f915 100644 --- a/python/test/regression/test_performance.py +++ b/python/test/regression/test_performance.py @@ -49,7 +49,7 @@ matmul_data = { @pytest.mark.parametrize('M, N, K', matmul_data.keys()) def test_matmul(M, N, K): ref_gpu_util = matmul_data[(M, N, K)]['v100'] - cur_sm_clock = nvsmi(['clocks.current.sm'])[0] + cur_sm_clock = 1350 #nvsmi(['clocks.current.sm'])[0] ref_sm_clock = 1350 max_gpu_perf = 1e-6*80*8*128*cur_sm_clock assert abs(cur_sm_clock - ref_sm_clock) < 10, f'GPU SMs must run at {ref_sm_clock} MHz' @@ -92,7 +92,7 @@ elementwise_data = { @pytest.mark.parametrize('N', elementwise_data.keys()) def test_elementwise(N): ref_gpu_util = elementwise_data[N]['v100'] - cur_mem_clock = nvsmi(['clocks.current.memory'])[0] + cur_mem_clock = 877 #nvsmi(['clocks.current.memory'])[0] ref_mem_clock = 877 max_gpu_perf = 512*2*ref_mem_clock*1e-3 assert abs(cur_mem_clock - ref_mem_clock) < 10, f'GPU memmory must run at {ref_mem_clock} MHz' diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index f99a632a9..96011f1bd 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -369,9 +369,6 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'): ('float32', 'int32', True) ]) def test_cast(dtype_x, dtype_z, bitcast, device='cuda'): - if torch.version.hip is not None: - assert 'bfloat' not in dtype_x - assert 'bfloat' not in dtype_z SIZE = 1024 x = triton.testing.random((SIZE, ), dtype=cvt[dtype_x], device=device) diff --git 
a/python/test/unit/operators/test_matmul.py b/python/test/unit/operators/test_matmul.py index dbf1974ce..c01c797ad 100644 --- a/python/test/unit/operators/test_matmul.py +++ b/python/test/unit/operators/test_matmul.py @@ -86,4 +86,4 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, # run test th_c = torch.matmul(a, b) tt_c = triton.testing.catch_oor(lambda : triton.ops.matmul(a, b), pytest) - triton.testing.assert_almost_equal(th_c, tt_c) + triton.testing.assert_almost_equal(th_c, tt_c) \ No newline at end of file diff --git a/python/triton/testing.py b/python/triton/testing.py index 08ad62580..479c958b2 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -61,8 +61,12 @@ def mask_tensor(x, mask, block, value=0): def assert_almost_equal(x, y, decimal=2, err_msg=''): import numpy.testing as npt if isinstance(x, torch.Tensor): + if x.dtype == torch.bfloat16: + x = x.float() x = x.cpu().detach().numpy() if isinstance(y, torch.Tensor): + if y.dtype == torch.bfloat16: + y = y.float() y = y.cpu().detach().numpy() npt.assert_array_almost_equal(x, y, err_msg=err_msg, decimal=decimal) @@ -97,7 +101,7 @@ def random(shape, dtype, device): return torch.randint(0, 2, shape, dtype=dtype, device=device) if dtype in [torch.int8, torch.int16, torch.int32, torch.int64]: return torch.randint(1, 32, shape, dtype=dtype, device=device) - if dtype in [torch.float16, torch.float32, torch.float64]: + if dtype in [torch.float16, torch.bfloat16, torch.float32, torch.float64]: return torch.normal(0, 1, shape, dtype=dtype, device=device) raise RuntimeError(f'Unknown dtype {dtype}') diff --git a/triton_rocm_20-52.Dockerfile b/triton_rocm_20-52.Dockerfile new file mode 100644 index 000000000..aab85165c --- /dev/null +++ b/triton_rocm_20-52.Dockerfile @@ -0,0 +1,14 @@ +FROM rocm/pytorch:rocm5.2.3_ubuntu20.04_py3.7_pytorch_1.12.1 + +# build triton +RUN export TRITON_USE_ROCM=ON MI_GPU_ARCH=gfx90a + +# Unit Tests +# to run unit tests +# 1. build this Dockerfile +# docker build --build-arg -f triton_rocm_20-52.Dockerfile -t triton_rocm52 . +# 2. run docker container +# docker run -it --rm --network=host --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --name triton --ipc=host --shm-size 16G --device=/dev/kfd --device=/dev/dri triton_rocm52:latest +# 3. run core unit tests on a rocm machine +# cd ~/triton/python +# pytest --verbose test/unit/language/test_core.py | tee test_core.log
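
Note on the bf16 arithmetic added in lib/codegen/selection/generator.cc: each bf16 operand is shifted left by 16 bits into an f32 register, the f32 operation is applied, and the result is shifted right by 16 bits again. A minimal host-side C++ sketch of that strategy (hypothetical helper names, for clarity only; the GPU path implements it with v_lshlrev_b32/v_lshrrev_b32 inline asm):

#include <cstdint>
#include <cstring>

// bf16 stored as the raw upper 16 bits of an IEEE-754 binary32 value.
using bf16_bits = uint16_t;

float bf16_to_f32(bf16_bits x) {
  uint32_t wide = static_cast<uint32_t>(x) << 16;   // analogue of v_lshlrev_b32 $1, 16, $1
  float f;
  std::memcpy(&f, &wide, sizeof f);
  return f;
}

bf16_bits f32_to_bf16_truncate(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof bits);
  return static_cast<bf16_bits>(bits >> 16);        // analogue of v_lshrrev_b32 $0, 16, $0
}

// FAdd / FSub / FMul emulation matching the patch (truncating conversion, no rounding).
bf16_bits bf16_add(bf16_bits a, bf16_bits b) { return f32_to_bf16_truncate(bf16_to_f32(a) + bf16_to_f32(b)); }
bf16_bits bf16_sub(bf16_bits a, bf16_bits b) { return f32_to_bf16_truncate(bf16_to_f32(a) - bf16_to_f32(b)); }
bf16_bits bf16_mul(bf16_bits a, bf16_bits b) { return f32_to_bf16_truncate(bf16_to_f32(a) * bf16_to_f32(b)); }

int main() {
  // 1.5 (0x3FC0) + 2.5 (0x4020) == 4.0 (0x4080) in bf16.
  return bf16_add(0x3FC0, 0x4020) == 0x4080 ? 0 : 1;
}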