[ROCM] enable matmul(dot) and others (#391)

Michael Melesse
2021-12-13 12:28:15 -08:00
committed by GitHub
parent 73b04d71b2
commit 94d5c2e8b5
12 changed files with 251 additions and 52 deletions

View File

@@ -7,6 +7,8 @@ if(NOT TRITON_LLVM_BUILD_DIR)
set(TRITON_LLVM_BUILD_DIR ${CMAKE_BINARY_DIR})
endif()
set(TRITON_USE_ROCM "$ENV{TRITON_USE_ROCM}")
set(TRITON_ROCM_DEBUG "$ENV{TRITON_ROCM_DEBUG}")
project(triton)
include(CTest)
@@ -37,7 +39,11 @@ if(WIN32)
add_subdirectory(deps/dlfcn-win32/src ${CMAKE_BINARY_DIR}/dlfcn-win32)
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17")
if (TRITON_USE_ROCM)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17 -Wno-unused-result -Wno-attributes")
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17")
endif()
##########
@@ -110,6 +116,15 @@ libLLVMBinaryFormat.a
libLLVMAMDGPUInfo.a
libLLVMSupport.a
libLLVMDemangle.a
libLLVMPasses.a
libLLVMAnalysis.a
libLLVMTransformUtils.a
libLLVMScalarOpts.a
libLLVMTransformUtils.a
libLLVMipo.a
libLLVMObjCARCOpts.a
libLLVMCoroutines.a
libLLVMAnalysis.a
)
endif()
include_directories("${LLVM_INCLUDE_DIRS}")
@@ -128,6 +143,13 @@ if(BUILD_PYTHON_MODULE)
endif()
include_directories("." ${PYTHON_SRC_PATH} ${PYTHON_INCLUDE_DIRS} ${CUTLASS_INCLUDE_DIR})
link_directories(${PYTHON_LINK_DIRS} ${CUTLASS_LIBRARY_DIR})
if (TRITON_USE_ROCM)
add_definitions(-DUSE_ROCM)
endif()
if (TRITON_ROCM_DEBUG)
add_definitions(-DDEBUG_ROCM)
endif()
set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/triton.cc ${PYTHON_SRC_PATH}/superblock.cc ${CUTLASS_SRC})
endif()
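
The two environment-driven CMake options added above gate the ROCm code paths: TRITON_USE_ROCM turns into -DUSE_ROCM and TRITON_ROCM_DEBUG into -DDEBUG_ROCM (which makes the driver dump GCN ISA to /tmp, see the llvm driver diff further down). A minimal sketch of setting them by hand before an editable install, assuming the repository's usual python/ package layout (setup.py below sets TRITON_USE_ROCM automatically when torch.version.hip is present):

    import os
    import subprocess

    # Hypothetical manual override of the env vars read by the CMakeLists.txt above.
    env = os.environ.copy()
    env["TRITON_USE_ROCM"] = "ON"    # build with -DUSE_ROCM
    env["TRITON_ROCM_DEBUG"] = "ON"  # build with -DDEBUG_ROCM (ISA dumps in /tmp)
    subprocess.check_call(["pip", "install", "-e", "python/"], env=env)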

View File

@@ -11,7 +11,6 @@
#include "triton/external/CUDA/nvml.h"
//// HIP backend
//#define __HIP_PLATFORM_AMD__
#include "triton/external/hip.h"
//Exceptions

View File

@@ -42,6 +42,7 @@ public:
value *get_int64(int64_t val);
value *get_float16(float val);
value *get_float32(float val);
value *get_float64(float val);
value *get_range(int32_t lo, int32_t hi);
// Types
type *get_void_ty();

View File

@@ -145,7 +145,7 @@ mma_layout::mma_layout(size_t num_warps,
shared_layout *layout_a, shared_layout *layout_b): distributed_layout(MMA, axes, shape, values, align) {
/* fragments per warp */
// try to make things as square as possible to maximize data re-use
if(tgt->as_nvidia()->sm() < 80){
if(tgt->as_nvidia() && tgt->as_nvidia()->sm() < 80){
fpw_ = {2, 2, 1};
auto ord_a = layout_a->get_order();
auto ord_b = layout_b->get_order();

View File

@@ -26,6 +26,7 @@ namespace codegen {
// There should be a proper pass manager there!
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx, codegen::target* target,
int cc, int num_warps, int num_stages, int& shared_static) {
// generate llvm code
std::string name = ir.get_function_list()[0]->get_name();
std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));

View File

@@ -14,7 +14,13 @@
#include "triton/ir/type.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IR/Type.h"
#ifdef USE_ROCM
#include "llvm/IR/IntrinsicsAMDGPU.h"
#else
#include "llvm/IR/IntrinsicsNVPTX.h"
#endif
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/InlineAsm.h"
@@ -86,6 +92,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
#define void_ty builder_->getVoidTy()
#define f16_ty builder_->getHalfTy()
#define f32_ty builder_->getFloatTy()
#define f64_ty builder_->getDoubleTy()
#define i8_ty builder_->getInt8Ty()
#define i32_ty builder_->getInt32Ty()
#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
@@ -464,7 +471,7 @@ Value* generator::bf16_to_fp32(Value *in0){
}
Value* generator::fp32_to_bf16(Value *in0){
if(tgt_->as_nvidia()->sm() >= 80){
if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80){
InlineAsm *ptx = InlineAsm::get(FunctionType::get(builder_->getInt16Ty(), {builder_->getFloatTy()}, false),
"cvt.rn.bf16.f32 $0, $1;", "=h,r", false);
return call(ptx, {in0});
@@ -584,6 +591,22 @@ void generator::visit_load_inst(ir::load_inst* x){
ir::value *op = x->get_pointer_operand();
ir::masked_load_inst *mx = dynamic_cast<ir::masked_load_inst*>(x);
Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty());
#ifdef USE_ROCM
// code generation
auto idxs = idxs_.at(x);
for(size_t i = 0; i <idxs.size(); i += 1){
indices_t idx = idxs[i];
// pointer value
Value *ptr = vals_[op][idx];
// create load
Value *_ret = builder_->CreateLoad(ty, ptr);
// upload to global vals map
vals_[x][idx] = _ret;
}
#else
// compute vector width
size_t vec = 1;
if(op->get_type()->is_block_ty()){
@@ -715,6 +738,7 @@ void generator::visit_load_inst(ir::load_inst* x){
for(size_t ii = 0; ii < vec; ii++)
vals_[x][idxs[i+ii]] = extract_elt(rets[ii/tmp], ii % tmp);
}
#endif
}
void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
@@ -733,6 +757,23 @@ void generator::visit_store_inst(ir::store_inst * x){
// operands
ir::value *ptr_op = x->get_pointer_operand();
ir::value *val_op = x->get_value_operand();
#ifdef USE_ROCM
auto idxs = idxs_.at(val_op);
Type *ty = cvt(val_op->get_type()->get_scalar_ty());
for (size_t i = 0; i < idxs.size(); i += 1)
{
auto idx = idxs[i];
// pointer
Value *ptr = vals_[ptr_op][idx];
// value
Value *val = vals_.at(val_op)[idxs[i]];
// store value at pointer
store(val, ptr);
}
#else
// vector size
size_t vec = 1;
if(val_op->get_type()->is_block_ty()){
@@ -766,6 +807,7 @@ void generator::visit_store_inst(ir::store_inst * x){
else
store(val, ptr);
}
#endif
}
void generator::visit_unmasked_store_inst(ir::unmasked_store_inst* x) {
visit_store_inst(x);
@@ -858,7 +900,12 @@ void generator::visit_exp_inst(ir::exp_inst* x){
Constant *log2e = ConstantFP::get(f32_ty, 1.4426950408889634);
std::vector<llvm::Type*> tys = {f32_ty};
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
#ifdef USE_ROCM
llvm::Function *ex2 = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::exp2, tys);
#else
InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $0;", "=f,0", false);
#endif
for(auto idx: idxs_.at(x)){
Value *ex2arg = fmul(vals_[x->get_operand(0)][idx], log2e);
vals_[x][idx] = call(ex2, std::vector<llvm::Value*>{ex2arg});
@@ -871,7 +918,11 @@ void generator::visit_exp_inst(ir::exp_inst* x){
void generator::visit_cos_inst(ir::cos_inst* x){
std::vector<llvm::Type*> tys = {f32_ty};
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
#ifdef USE_ROCM
llvm::Function *cos = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::cos, tys);
#else
InlineAsm *cos = InlineAsm::get(fn_ty, "cos.approx.f32 $0, $0;", "=f,0", false);
#endif
for(auto idx: idxs_.at(x)){
vals_[x][idx] = call(cos, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
}
@@ -897,7 +948,11 @@ void generator::visit_umulhi_inst(ir::umulhi_inst* x){
void generator::visit_sin_inst(ir::sin_inst* x){
std::vector<llvm::Type*> tys = {f32_ty};
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
#ifdef USE_ROCM
llvm::Function *sin = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::sin, tys);
#else
InlineAsm *sin = InlineAsm::get(fn_ty, "sin.approx.f32 $0, $0;", "=f,0", false);
#endif
for(auto idx: idxs_.at(x)){
vals_[x][idx] = call(sin, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
}
@@ -910,7 +965,11 @@ void generator::visit_log_inst(ir::log_inst* x){
Constant *rcplog2e = ConstantFP::get(f32_ty, 0.6931471805599453);
std::vector<llvm::Type*> tys = {f32_ty};
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
#ifdef USE_ROCM
llvm::Function *lg2 = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::log2, tys);
#else
InlineAsm *lg2 = InlineAsm::get(fn_ty, "lg2.approx.f32 $0, $1;", "=f,f", false);
#endif
for(auto idx: idxs_.at(x)){
Value *lg2arg = call(lg2, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
vals_[x][idx] = fmul(lg2arg, rcplog2e);
@@ -1701,10 +1760,14 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
size_t red_axis = 1;
unsigned NK = A_shapes[red_axis];
bool is_outer = NK == 1;
#ifdef USE_ROCM
bool is_mma = layouts_->get(dot)->to_mma();
if(!is_outer && is_mma && tgt_->as_nvidia()->sm() < 80)
#else
bool is_mma = false;
#endif
if(!is_outer && is_mma && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80)
return visit_mma884(dot, A, B, D, NK);
if(!is_outer && is_mma && tgt_->as_nvidia()->sm() >= 80)
if(!is_outer && is_mma && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
return visit_mma16816(dot, A, B, D, NK);
return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
}
@@ -1739,8 +1802,14 @@ Value* generator::shared_off(const std::vector<unsigned>& shapes, const std::vec
inline Value* generator::shfl_sync(Value* acc, int32_t i){
Type* ty = acc->getType();
#ifdef USE_ROCM
std::string asm_str = "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;";
InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false);
#else
std::string asm_str = "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;";
InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false);
#endif
if(ty->getPrimitiveSizeInBits() <= 32)
return call(shfl, {acc, i32(i)});
acc = builder_->CreateBitCast(acc, vec_ty(f32_ty, 2));
@@ -1902,8 +1971,14 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
default: throw std::runtime_error("unreachable");
}
ir::value *arg = x->get_operand(0);
if(arg->get_type()->get_tile_rank() == 1)
if (arg->get_type()->get_tile_rank() == 1)
{
#ifdef USE_ROCM
visit_reducend_inst(x, do_acc, neutral);
#else
visit_reduce1d_inst(x, do_acc, neutral);
#endif
}
else
visit_reducend_inst(x, do_acc, neutral);
}
@@ -2286,12 +2361,14 @@ void generator::visit_function(ir::function* fn) {
// set metadata
if(tgt_->is_gpu()){
tgt_->set_kernel(*builder_, ctx, mod_, ret);
#ifndef USE_ROCM
Metadata *md_args[] = {
ValueAsMetadata::get(ret),
MDString::get(ctx, "maxntidx"),
ValueAsMetadata::get(i32(num_warps_*32))
};
mod_->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(ctx, md_args));
#endif
}
// set arguments
for(unsigned i = 0; i < fn->args().size(); i++)
@@ -2311,6 +2388,9 @@ void generator::visit_function(ir::function* fn) {
visit_basic_block(block);
// finalize
finalize_function(fn);
// verifyFunction
llvm::verifyFunction(*ret);
}
@@ -2334,7 +2414,11 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) {
Value *_8 = i32(8);
Value *_16 = i32(16);
Value *_32 = i32(32);
#ifdef USE_ROCM
int cc = 1; // generate ir for older CUDA cards
#else
int cc = tgt_->as_nvidia()->sm();
#endif
std::vector<Value*> idx_m;
std::vector<Value*> idx_n;
std::vector<Value*> idx_z;
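
Under USE_ROCM the generator above scalarizes loads and stores and lowers exp/cos/sin/log through LLVM intrinsics (exp2, cos, sin, log2) instead of PTX inline asm. A minimal sketch of a kernel that exercises those paths on a ROCm device; the kernel is illustrative, not part of this commit, and uses the same 1.x **meta style as the tests further down:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def math_kernel(X, Z, **meta):
        off = tl.arange(0, meta['SIZE'])
        x = tl.load(X + off)                  # scalarized loads under USE_ROCM
        y = tl.exp(x) + tl.log(x + 2.0) + tl.sin(x) + tl.cos(x)
        tl.store(Z + off, y)                  # scalarized stores under USE_ROCM

    x = torch.rand(1024, device='cuda')       # ROCm PyTorch exposes HIP devices as 'cuda'
    z = torch.empty_like(x)
    math_kernel[(1,)](x, z, SIZE=1024, num_warps=4)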

View File

@@ -41,7 +41,7 @@ Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, un
}
Instruction* amd_cl_target::add_memfence(Module *module, IRBuilder<>& builder) {
throw std::runtime_error("not implemented");
throw std::runtime_error("not implemented on AMD");
}
@@ -156,7 +156,7 @@ Value* cpu_target::get_block_id(Module *module, llvm::IRBuilder<> &builder, unsi
}
Value* cpu_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
throw std::runtime_error("not implemented");
throw std::runtime_error("not implemented on CPU");
}

View File

@@ -25,6 +25,7 @@
#endif
#include <memory>
#include <regex>
#include <iomanip>
#include "triton/driver/llvm.h"
#include "triton/driver/dispatch.h"
#include "triton/driver/error.h"
@@ -56,6 +57,8 @@
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/Intrinsics.h"
// end AMD stuff
namespace triton{
@@ -264,8 +267,13 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) {
llvm::SmallVector<char, 0> buffer;
std::string triple = "amdgcn-amd-amdhsa";
std::string layout = "";
std::string features;
std::string features="+sramecc,-xnack";
std::string proc = "gfx908";
// name kernel
auto in_time_t = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
std::stringstream cur_time;
cur_time << std::put_time(std::localtime(&in_time_t), "%Y-%m-%d--%I-%M-%S");
std::string kernel_name = module->getModuleIdentifier() + "_" + cur_time.str();
// verify and store llvm
llvm::legacy::PassManager pm;
pm.add(llvm::createVerifierPass());
@@ -281,7 +289,7 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) {
opt.NoNaNsFPMath = true;
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
llvm::Reloc::PIC_, llvm::None,
llvm::CodeGenOpt::Aggressive);
llvm::CodeGenOpt::None);
// set data layout
if(layout.empty())
module->setDataLayout(machine->createDataLayout());
@@ -294,11 +302,10 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) {
llvm::raw_svector_ostream stream(buffer);
// create dump files
std::string module_name = module->getModuleIdentifier();
std::error_code ec;
// Save GCN ISA binary.
std::string isabin_path = std::string("/tmp/") + module_name + std::string(".o");
std::string isabin_path = std::string("/tmp/") + kernel_name + std::string(".o");
std::unique_ptr<llvm::raw_fd_ostream> isabin_fs(
new llvm::raw_fd_ostream(isabin_path, ec, llvm::sys::fs::OF_Text));
if (ec)
@@ -309,15 +316,25 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) {
// emit
machine->addPassesToEmitFile(pass, *isabin_fs, nullptr, llvm::CGFT_ObjectFile);
pass.run(*module);
#ifdef DEBUG_ROCM
std::cout << "Generating GCN ISA file" << std::endl;
llvm::SmallVector<char, 0> debugBuffer;
llvm::legacy::PassManager debugPass;
llvm::raw_svector_ostream debugStream(debugBuffer);
machine->addPassesToEmitFile(debugPass, debugStream, nullptr, llvm::CodeGenFileType::CGFT_AssemblyFile); // TODO: causes segfaults on REM ops and also triggers the @llvm.amdgcn.if bug
debugPass.run(*module);
// Save GCN ISA.
std::string amdgcn_path = std::string("/tmp/") + module_name + std::string(".gcn");
std::string result(buffer.begin(), buffer.end());
std::string amdgcn_path = std::string("/tmp/") + kernel_name + std::string(".gcn");
std::string result(debugBuffer.begin(), debugBuffer.end());
std::ofstream amdgcn(amdgcn_path);
amdgcn << result;
amdgcn.close();
#endif
// generate HSACO file
std::string hsaco_path = std::string("/tmp/") + module_name + std::string(".hsaco");
std::string hsaco_path = std::string("/tmp/") + kernel_name + std::string(".hsaco");
std::string error_message;
int lld_result =
llvm::sys::ExecuteAndWait("/opt/rocm/llvm/bin/ld.lld",
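
When DEBUG_ROCM is defined, the driver change above also emits a timestamped .gcn assembly dump alongside the .o and .hsaco files in /tmp. A small, purely illustrative helper for pulling up the newest dump:

    import glob
    import os

    # Pick the most recently written GCN ISA dump produced by the DEBUG_ROCM path above.
    dumps = sorted(glob.glob('/tmp/*.gcn'), key=os.path.getmtime)
    if dumps:
        with open(dumps[-1]) as f:
            print(f.read())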

View File

@@ -60,6 +60,9 @@ value *builder::get_float16(float val)
value *builder::get_float32(float val)
{ return constant_fp::get(type::get_fp32_ty(ctx_), val); }
value *builder::get_float64(float val)
{ return constant_fp::get(type::get_fp64_ty(ctx_), val); }
value *builder::get_range(int32_t _lo, int32_t _hi) {
constant_int* lo = static_cast<constant_int*>(get_int32(_lo));
constant_int* hi = static_cast<constant_int*>(get_int32(_hi));

View File

@@ -559,7 +559,27 @@ ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::
ptr = dispatch::cast(ptr, ptr_ty, builder);
}
// cast to target data-type
#ifdef USE_ROCM
ir::type *src_ty = val->get_type();
ir::type *dst_ty = elt_ty;
if (src_ty->is_block_ty())
dst_ty = ir::block_type::get(dst_ty, src_ty->get_block_shapes());
ir::type *src_sca_ty = src_ty->get_scalar_ty();
ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
// check if truncation is needed
bool truncate_fp = src_sca_ty->is_floating_point_ty() &&
dst_sca_ty->is_floating_point_ty() &&
src_sca_ty->get_fp_mantissa_width() > dst_sca_ty->get_fp_mantissa_width();
if (truncate_fp && elt_ty->is_fp16_ty())
{
std::cout << "WARNING: "<<"casting down to fp16 is broken on ROCM" << std::endl;
}
val = dispatch::cast(val, elt_ty, builder);
#else
val = dispatch::cast(val, elt_ty, builder);
#endif
if (!mask)
return builder->create_store(ptr, val);
if(!mask->get_type()->get_scalar_ty()->is_bool_ty())
@@ -659,6 +679,8 @@ ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::value *mask
//===----------------------------------------------------------------------===//
ir::value *dispatch::dot(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
lhs = dispatch::cast(lhs, builder->get_float_ty(), builder);
rhs = dispatch::cast(rhs, builder->get_float_ty(), builder);
ir::value *_0 = builder->get_float32(0);
unsigned M = lhs->get_type()->get_block_shapes()[0];
unsigned N = rhs->get_type()->get_block_shapes()[1];
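
dispatch::dot above now casts both operands to fp32 before building the dot, so on ROCm matmul runs through the fp32 FMA path rather than tensor-core MMA. A rough usage sketch in the 1.x API; the kernel and shapes are illustrative (mirroring the test_dot sizes below) and are not defined by this commit:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def dot_kernel(X, Y, Z, **meta):
        M, N, K = meta['M'], meta['N'], meta['K']
        off_m = tl.arange(0, M)
        off_n = tl.arange(0, N)
        off_k = tl.arange(0, K)
        x = tl.load(X + off_m[:, None] * K + off_k[None, :])   # row-major (M, K)
        y = tl.load(Y + off_k[:, None] * N + off_n[None, :])   # row-major (K, N)
        z = tl.dot(x, y)                                        # routed through dispatch::dot
        tl.store(Z + off_m[:, None] * N + off_n[None, :], z)

    M, N, K = 64, 64, 32
    x = torch.randn((M, K), dtype=torch.float16, device='cuda')
    y = torch.randn((K, N), dtype=torch.float16, device='cuda')
    z = torch.empty((M, N), dtype=torch.float32, device='cuda')
    dot_kernel[(1, 1)](x, y, z, M=M, N=N, K=K, num_warps=4)
    assert torch.allclose(z, x.float() @ y.float(), atol=1e-2)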

View File

@@ -15,10 +15,11 @@ from setuptools.command.test import test as TestCommand
import distutils.spawn
import urllib.request
import tarfile
import torch
def get_llvm():
# tries to find system LLVM
versions = ['-11.0', '-11', '-11-64']
versions = ['-13.0', '-13', '-13-64']
supported = ['llvm-config{v}'.format(v=v) for v in versions]
paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
paths = [p for p in paths if p is not None]
@@ -27,7 +28,7 @@ def get_llvm():
if platform.system() == "Windows":
return '', ''
# download if nothing is installed
name = 'clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04'
name = 'clang+llvm-13.0.0-x86_64-linux-gnu-ubuntu-16.04'
dir = '/tmp'
llvm_include_dir = '{dir}/{name}/include'.format(dir=dir, name=name)
llvm_library_dir = '{dir}/{name}/lib'.format(dir=dir, name=name)
@@ -36,7 +37,7 @@ def get_llvm():
shutil.rmtree(os.path.join(dir, name))
except:
pass
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/{name}.tar.xz".format(name=name)
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-13.0.0/{name}.tar.xz".format(name=name)
print('downloading and extracting ' + url + '...')
ftpstream = urllib.request.urlopen(url)
file = tarfile.open(fileobj=ftpstream, mode="r|xz")
@@ -80,7 +81,7 @@ class CMakeBuild(build_ext):
def build_extension(self, ext):
llvm_include_dir, llvm_library_dir = get_llvm()
# self.debug = True
self.debug = True
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
# create build directories
build_suffix = 'debug' if self.debug else 'release'
@@ -90,7 +91,10 @@ class CMakeBuild(build_ext):
if not os.path.exists(llvm_build_dir):
os.makedirs(llvm_build_dir)
# python directories
python_include_dirs = [distutils.sysconfig.get_python_inc()] + ['/usr/local/cuda/include']
if torch.version.hip is not None:
python_include_dirs= [distutils.sysconfig.get_python_inc()] +['/opt/rocm/include']
else:
python_include_dirs = [distutils.sysconfig.get_python_inc()] + ['/usr/local/cuda/include']
cmake_args = [
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
"-DBUILD_TUTORIALS=OFF",
@@ -117,6 +121,9 @@ class CMakeBuild(build_ext):
build_args += ["--", '-j' + str(2 * multiprocessing.cpu_count())]
env = os.environ.copy()
if torch.version.hip is not None:
env["TRITON_USE_ROCM"] = "ON"
subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=self.build_temp, env=env)
subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp)

View File

@@ -45,6 +45,35 @@ def test_empty_kernel(dtype_x, device='cuda'):
x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
kernel[(1, )](x, SIZE=SIZE, num_warps=4)
# ---------------
# test load and store op
# ---------------
@pytest.mark.parametrize("dtype,size", [
(dtype, size)
for dtype in dtypes
for size in [128, 256, 512, 1024, 2048, 4096]
])
def test_load_and_store_op(dtype, size, device='cuda'):
SIZE = size
# define the kernel / launch-grid
@triton.jit
def kernel(Z, X, **meta):
off = tl.arange(0, meta['SIZE'])
x = tl.load(X + off)
tl.store(Z + off, x)
# inputs
x = triton.testing.random(SIZE, dtype=cvt[dtype], device=device)
# output tensors
z_ref = x.clone() # reference result
z_tri = torch.empty_like(x) # triton result
# run load and store kernel
kernel[(1, )](z_tri, x, SIZE=SIZE, num_warps=4)
# compare
triton.testing.assert_almost_equal(z_ref, z_tri)
# generic test functions
def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'):
SIZE = 128
@@ -340,18 +369,23 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
('float32', 'int32', True)
])
def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
x = torch.tensor([43.5], dtype=cvt[dtype_x], device=device)
if torch.version.hip is not None:
assert 'bfloat' not in dtype_x
assert 'bfloat' not in dtype_z
SIZE = 1024
x = triton.testing.random((SIZE, ), dtype=cvt[dtype_x], device=device)
# triton kernel
@triton.jit
def kernel(X, Z, **meta):
x = tl.load(X)
off = tl.arange(0, meta['SIZE'])
x = tl.load(X+ off)
z = x.to(Z.dtype.element_ty, bitcast=meta['BITCAST'])
tl.store(Z, z)
tl.store(Z+ off, z)
# triton result
z_tri = torch.empty((1, ), dtype=cvt[dtype_z], device=device)
kernel[(1, )](x, z_tri, BITCAST=bitcast)
z_tri = torch.empty((SIZE, ), dtype=cvt[dtype_z], device=device)
kernel[(1, )](x, z_tri, SIZE=SIZE, BITCAST=bitcast)
# torch result
if bitcast:
import numpy as np
@@ -359,7 +393,7 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
z_ref = torch.from_numpy(z_ref).to(device)
else:
z_ref = x.to(z_tri.dtype)
assert z_tri == z_ref
triton.testing.assert_almost_equal(z_ref, z_tri)
# ---------------
# test reduce
@@ -448,17 +482,23 @@ def test_permute(dtype, shape, perm, device='cuda'):
z_ref = x.permute(*perm).contiguous()
# compare
triton.testing.assert_almost_equal(z_tri, z_ref)
# parse ptx to make sure ld/st are vectorized
ptx = pgm.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
if torch.version.hip is None:
# parse ptx to make sure ld/st are vectorized
ptx = pgm.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
# ---------------
# test dot
# ---------------
@pytest.mark.parametrize("epilogue", ['none', 'add-matrix', 'add-rows', 'add-cols'])
def test_dot(epilogue, device='cuda'):
@pytest.mark.parametrize("dtype, epilogue", [(dtype, epilogue)\
for dtype in ['float16','float32'] \
for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols']])
def test_dot(dtype, epilogue, device='cuda'):
dtype = cvt[dtype]
torch.manual_seed(0)
# triton kernel
@triton.jit
@@ -486,10 +526,10 @@ def test_dot(epilogue, device='cuda'):
tl.store(Zs, z)
# input
M, N, K = 64, 64, 32
x = triton.testing.random((M, K), dtype=torch.float16, device=device)
y = triton.testing.random((K, N), dtype=torch.float16, device=device)
x = triton.testing.random((M, K), dtype=dtype, device=device)
y = triton.testing.random((K, N), dtype=dtype, device=device)
# triton result
z = triton.testing.random((M, N), dtype=torch.float16, device=device)
z = triton.testing.random((M, N), dtype=dtype, device=device)
z_tri = z.clone()
pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1),
y, y.stride(0), y.stride(1),
@@ -508,12 +548,14 @@ def test_dot(epilogue, device='cuda'):
z_ref += z[0,:][None, :]
z_ref = z_ref.to(torch.float16)
# compare
ptx = pgm.asm['ptx']
# print(ptx)
triton.testing.assert_almost_equal(z_tri, z_ref)
# make sure ld/st are vectorized
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
# print(ptx)
if torch.version.hip is None:
ptx = pgm.asm['ptx']
# make sure ld/st are vectorized
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
def test_dot_without_load():
@triton.jit
@@ -611,17 +653,18 @@ def test_load_cache_modifier(cache):
tl.store(dst+offsets, x)
pgm = _kernel[(1,)](dst, src, CACHE=cache)
ptx = pgm.asm['ptx']
if torch.version.hip is None:
ptx = pgm.asm['ptx']
if cache == '':
assert 'ld.global.ca' not in ptx
assert 'ld.global.cg' not in ptx
if cache == '.cg':
assert 'ld.global.cg' in ptx
assert 'ld.global.ca' not in ptx
if cache == '.ca':
assert 'ld.global.ca' in ptx
assert 'ld.global.cg' not in ptx
if cache == '':
assert 'ld.global.ca' not in ptx
assert 'ld.global.cg' not in ptx
if cache == '.cg':
assert 'ld.global.cg' in ptx
assert 'ld.global.ca' not in ptx
if cache == '.ca':
assert 'ld.global.ca' in ptx
assert 'ld.global.cg' not in ptx
# ---------------
# test store
@@ -647,4 +690,4 @@ def test_noop(device='cuda'):
def kernel(**meta):
pass
x = triton.testing.random((1,), dtype=torch.int32, device=device)
kernel[(1, )](x)
kernel[(1, )](x)
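
The new and updated tests above can be exercised directly with pytest; on ROCm, torch.version.hip is set and the PTX-specific assertions are skipped. A sketch of a focused run (the test file's path is not shown in this view, so the placeholder below has to be substituted):

    import sys
    import pytest

    # PATH/TO/test_file.py is a placeholder for the test module from the last diff above.
    sys.exit(pytest.main(["-q", "-k", "load_and_store or cast or dot or load_cache_modifier",
                          "PATH/TO/test_file.py"]))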