[GENERAL] Removed deprecated driver files and added basic compatibility with ROCm (#268)
- Removed the driver module -- the accelerator runtime is now handled by PyTorch.
- Added basic ROCm support based on @micmelesse's PR -- an empty kernel can now be executed on AMD devices without any compile-time changes (a usage sketch follows below).
- PREFER_SHARED is now only requested for kernels whose shared-memory footprint exceeds 49 KB; otherwise L1 performance can be poor for broadcast tensors.
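The practical effect is that the Python frontend no longer talks to per-device driver objects; every runtime call is tagged with a backend enum instead. A minimal sketch of the new entry points, assembled from the code_gen.py hunks further down (the `_triton` import path and the presence of a CUDA or ROCm device are assumptions of this example, not part of the diff):

    import torch
    import triton._C.libtriton.triton as _triton  # bindings module, imported as `_triton` in triton/code_gen.py

    # Backend selection as done in the new Kernel._compile: a ROCm build of
    # PyTorch reports a non-None torch.version.hip.
    backend = _triton.runtime.backend.CUDA if torch.version.hip is None else _triton.runtime.backend.ROCM

    # Device queries now go through the backend-agnostic runtime submodule
    # (init_triton_runtime) instead of drv::cu_device bindings.
    device_idx = torch.cuda.current_device()
    print(_triton.runtime.max_shared_memory(backend, device_idx))

`code_gen.compile_ttir` and `runtime.enqueue` take the same backend tag, as the Kernel._compile and Binary.__call__ hunks below show.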
@@ -1,7 +1,7 @@
#include "triton/codegen/pass.h"
#include "triton/driver/kernel.h"
#include "triton/driver/module.h"
#include "triton/driver/stream.h"
#include "triton/codegen/target.h"
#include "triton/driver/error.h"
#include "triton/driver/llvm.h"
#include "triton/ir/builder.h"
#include "triton/ir/dispatch.h"
#include "triton/ir/enums.h"
@@ -15,7 +15,9 @@
#include <pybind11/stl.h>
#include <regex>
#include <string>
#include <sstream>
#include "llvm/IR/Module.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Verifier.h"

namespace py = pybind11;
namespace ir = triton::ir;
@@ -24,72 +26,213 @@ namespace drv = triton::driver;
/*****************************************************************************/
/* Python bindings for triton::driver */
/*****************************************************************************/
// information query
template<CUdevice_attribute attr>
int cuGetInfo(CUdevice device) {
  int res;
  drv::dispatch::cuDeviceGetAttribute(&res, attr, device);
  return res;
}

void init_triton_driver(py::module &&m) {
  // base device
  py::class_<drv::device>(m, "device");
  // cuda device
  py::class_<drv::cu_device, drv::device>(m, "cu_device")
      .def(py::init([](int dev_id, bool take_ownership) {
        CUdevice handle;
        drv::dispatch::cuDeviceGet(&handle, dev_id);
        return new drv::cu_device(handle, take_ownership);
      }))
      .def("max_shared_memory", [](drv::cu_device *self) {
        return self->max_shared_memory();
      })
      .def("enable_peer_access", [](drv::cu_device *self, unsigned long long int peer_mem_ptr) {
        self->enable_peer_access(peer_mem_ptr);
      });
  // host device
  py::class_<drv::host_device, drv::device>(m, "host_device")
      .def(py::init<>());
template<hipDeviceAttribute_t attr>
int hipGetInfo(hipDevice_t device) {
  int res;
  drv::dispatch::hipDeviceGetAttribute(&res, attr, device);
  return res;
}

  // base stream
  py::class_<drv::stream>(m, "stream");
  // host stream
  py::class_<drv::host_stream, drv::stream>(m, "host_stream")
      .def(py::init<>());
  // cuda stream
  py::class_<drv::cu_stream, drv::stream>(m, "cu_stream")
      // py doesn't support opaque pointer (e.g., CUstream) so
      // we assume it has been converted to uint64_t
      .def(py::init([](uint64_t handle, bool take_ownership) {
        return std::unique_ptr<drv::cu_stream>(new drv::cu_stream((CUstream)handle, take_ownership));
      }))
      .def("enqueue", [](drv::cu_stream *self, drv::kernel *kernel,
                         size_t grid_0, size_t grid_1, size_t grid_2,
                         size_t block_0, size_t block_1, size_t block_2,
                         const std::string &args,
                         size_t shared_mem) {
        return self->enqueue(kernel, {grid_0, grid_1, grid_2}, {block_0, block_1, block_2},
                             (void *)args.data(), args.size(), shared_mem);
      });
enum backend_t {
  HOST,
  CUDA,
  ROCM,
};

py::class_<drv::module>(m, "module");
void cu_enable_peer_access(uint64_t peer_ptr){
  CUcontext context;
  drv::dispatch::cuPointerGetAttribute(&context, CU_POINTER_ATTRIBUTE_CONTEXT, peer_ptr);
  try {
    drv::dispatch::cuCtxEnablePeerAccess(context, 0);
  } catch (drv::exception::cuda::peer_access_already_enabled) {}
}

py::class_<drv::cu_module, drv::module>(m, "cu_module")
    .def("ptx", &drv::cu_module::ptx)
    .def("cubin", [](drv::cu_module *self) { return py::bytes(self->cubin()); })
    .def("llir", &drv::cu_module::llir);
void host_enqueue(uint64_t stream, uint64_t kernel,
                  uint64_t grid_0, uint64_t grid_1, uint64_t grid_2,
                  uint64_t block_0, uint64_t block_1, uint64_t block_2,
                  void* args_ptr, size_t args_size, int64_t shared_mem){
  throw std::runtime_error("unsupported");
  // auto hst = kernel->module()->hst();
  // hst_->futures->reserve(hst_->futures->size() + grid[0]*grid[1]*grid[2]);
  // char* params = new char[args_size];
  // std::memcpy((void*)params, (void*)args, args_size);
  // for(size_t i = 0; i < grid[0]; i++)
  //   for(size_t j = 0; j < grid[1]; j++)
  //     for(size_t k = 0; k < grid[2]; k++)
  //       hst_->futures->emplace_back(hst_->pool->enqueue(hst->fn, (char**)params, int32_t(i), int32_t(j), int32_t(k)));
}

py::class_<drv::kernel>(m, "kernel");
void cu_enqueue(uint64_t stream, uint64_t kernel,
                uint64_t grid_0, uint64_t grid_1, uint64_t grid_2,
                uint64_t block_0, uint64_t block_1, uint64_t block_2,
                void* args_ptr, size_t args_size, int64_t shared_mem){
  void *config[] = {
    CU_LAUNCH_PARAM_BUFFER_POINTER, (void*)args_ptr,
    CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
    CU_LAUNCH_PARAM_END
  };
  drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
                                block_0, block_1, block_2,
                                shared_mem, (CUstream)stream, nullptr, config);
}

void hip_enqueue(uint64_t stream, uint64_t kernel,
                 uint64_t grid_0, uint64_t grid_1, uint64_t grid_2,
                 uint64_t block_0, uint64_t block_1, uint64_t block_2,
                 void* args_ptr, size_t args_size, int64_t shared_mem) {
  void *config[] = {
    HIP_LAUNCH_PARAM_BUFFER_POINTER, (void*)args_ptr,
    HIP_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
    HIP_LAUNCH_PARAM_END
  };
  drv::dispatch::hipModuleLaunchKernel((hipFunction_t)kernel, grid_0, grid_1, grid_2,
                                       block_0, block_1, block_2,
                                       shared_mem, (hipStream_t)stream, nullptr, config);

}

void init_triton_runtime(py::module &&m) {

  // wrap backend_t
  py::enum_<backend_t>(m, "backend")
      .value("HOST", HOST)
      .value("CUDA", CUDA)
      .value("ROCM", ROCM)
      .export_values();

  // enable peer-to-peer
  m.def("enable_peer_access", [](backend_t backend, uint64_t peer_ptr) {
      if (backend != CUDA)
        throw std::runtime_error("P2P only supported on CUDA devices!");
      cu_enable_peer_access(peer_ptr);
    }
  );

  // query maximum shared memory
  m.def("max_shared_memory", [](backend_t backend, uint64_t device) {
    if (backend == HOST)
      return 0;
    if(backend == CUDA)
      return cuGetInfo<CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN>(device);
    if(backend == ROCM)
      return hipGetInfo<hipDeviceAttributeMaxSharedMemoryPerBlock>(device);
    return -1;
  });

  // enqueue
  m.def("enqueue", [](backend_t backend, uint64_t stream, uint64_t kernel,
                      uint64_t grid_0, uint64_t grid_1, uint64_t grid_2,
                      uint64_t block_0, uint64_t block_1, uint64_t block_2,
                      const std::string &args, int64_t shared_mem){
    void* args_ptr = (void*)args.data();
    size_t args_size = args.size();
    if(backend == HOST)
      host_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
    if(backend == CUDA)
      cu_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
    if(backend == ROCM)
      hip_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
  });

}

/*****************************************************************************/
/* Python bindings for triton::codegen */
/*****************************************************************************/
typedef std::map<std::string, std::string> asm_map_t;


std::tuple<uint64_t, uint64_t> cu_compile_llir(const std::string& name, size_t n_shared_bytes, llvm::Module* llvm, uint64_t dev, asm_map_t& asm_map, int cc, int version){
  // LLVM-IR -> PTX
  std::string ptx = drv::llir_to_ptx(llvm, cc, version);
  asm_map["ptx"] = ptx;
  // PTX -> Binary
  CUmodule mod = drv::ptx_to_cumodule(ptx, cc);
  // Handle to the kernel
  CUfunction fun;
  drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
  // Dynamic shared memory
  int shared_optin;
  drv::dispatch::cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev);
  if(n_shared_bytes > 49152 && shared_optin > 49152){
    drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
    int shared_total, shared_static;
    int n_spills, n_reg;
    drv::dispatch::cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, dev);
    drv::dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun);
    drv::dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
    drv::dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, fun);
    drv::dispatch::cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);
  }

  // record asm
  return std::make_tuple((uint64_t)mod, (uint64_t)fun);
}

std::tuple<uint64_t, uint64_t> hip_compile_llir(const std::string& name, llvm::Module* llvm, uint64_t dev, asm_map_t& asm_map){
  // LLVM-IR -> HSA-CO
  std::string path = drv::llir_to_amdgpu(llvm, "gfx908");
  // HSA-CO -> hipModule
  hipModule_t mod = drv::amdgpu_to_hipmodule(path);
  // Handle to the kernel
  hipFunction_t fun;
  drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str());
  // record asm
  return std::make_tuple((uint64_t)mod, (uint64_t)fun);
}

void init_triton_codegen(py::module &&m) {
  m.def(
    "add_passes_to_emit_bin", [](ir::module &ir, drv::device *dev, int num_warps, int num_stages, bool force_nc_cache) {
      drv::module *mod;
      drv::kernel *ker;
      size_t shared_mem;
      triton::codegen::add_passes_to_emit_bin(ir, dev, num_warps, num_stages, force_nc_cache, mod, ker, shared_mem);
      std::stringstream ss;
      ir::print(ir, ss);
      return std::make_tuple(mod, ker, shared_mem, ss.str());
    "compile_ttir", [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, bool force_nc_cache) {
      std::string name = ir.get_function_list()[0]->get_name();
      // record asm as we generate
      asm_map_t asm_map;
      std::ostringstream ttir;
      ir::print(ir, ttir);
      asm_map["ttir"] = ttir.str();
      llvm::LLVMContext ctx;
      if(backend == CUDA){
        // device properties
        CUdevice dev = (CUdevice)device;
        size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
        size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
        size_t cc = major*10 + minor;
        int version;
        drv::dispatch::cuDriverGetVersion(&version);
        // Triton-IR -> NVPTX LLVM-IR
        triton::codegen::nvidia_cu_target target(cc);
        int n_shared_bytes;
        auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, force_nc_cache, n_shared_bytes);
        llvm::raw_string_ostream llir(asm_map["llir"]);
        llir << *llvm;
        llir.flush();
        // LLVM-IR -> Bin
        uint64_t mod, fun;
        std::tie(mod, fun) = cu_compile_llir(name, n_shared_bytes, &*llvm, device, asm_map, cc, version);
        return std::make_tuple(mod, fun, asm_map, n_shared_bytes);
      }
      if(backend == ROCM){
        // Triton-IR -> AMDGPU LLVM-IR
        triton::codegen::amd_cl_target target;
        int n_shared_bytes;
        auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, 70, num_warps, num_stages, force_nc_cache, n_shared_bytes);
        llvm::raw_string_ostream llir(asm_map["llir"]);
        llir << *llvm;
        llir.flush();
        // LLVM-IR -> Bin
        uint64_t mod, fun;
        std::tie(mod, fun) = hip_compile_llir(name, &*llvm, device, asm_map);
        return std::make_tuple(mod, fun, asm_map, n_shared_bytes);
      }
    },
    py::return_value_policy::take_ownership);
}
@@ -302,7 +445,7 @@ void init_triton_ir(py::module &&m) {
void init_triton(py::module &m) {
  py::module subm = m.def_submodule("triton");
  init_triton_codegen(std::move(subm.def_submodule("code_gen")));
  init_triton_driver(std::move(subm.def_submodule("driver")));
  init_triton_runtime(std::move(subm.def_submodule("runtime")));
  init_triton_ir(std::move(subm.def_submodule("ir")));
  init_triton_frontend(std::move(subm.def_submodule("frontend")));
}
@@ -34,6 +34,8 @@ def patch_kernel(template, to_replace):
    return kernel


# generic test functions
def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'):
    SIZE = 128
@@ -425,7 +427,7 @@ def test_permute(dtype, shape, perm, device='cuda'):
    # compare
    triton.testing.assert_almost_equal(z_tri, z_ref)
    # parse ptx to make sure ld/st are vectorized
    ptx = pgm.asm('ptx')
    ptx = pgm.asm['ptx']
    assert 'ld.global.v4' in ptx
    assert 'st.global.v4' in ptx

@@ -484,7 +486,7 @@ def test_dot(epilogue, device='cuda'):
    z_ref += z[0,:][None, :]
    z_ref = z_ref.to(torch.float16)
    # compare
    ptx = pgm.asm('ptx')
    ptx = pgm.asm['ptx']
    # print(ptx)
    triton.testing.assert_almost_equal(z_tri, z_ref)
    # make sure ld/st are vectorized
@@ -511,3 +513,13 @@ def test_dot(epilogue, device='cuda'):
# ---------------
# test while
# ---------------

# ---------------
# test noop
#----------------
def test_noop(device='cuda'):
    @triton.jit
    def kernel(**meta):
        pass
    x = triton.testing.random((1,), dtype=torch.int32, device=device)
    kernel[(1, )](x)
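The `pgm.asm('ptx')` to `pgm.asm['ptx']` changes above follow from the Binary rework shown next: the per-stage `asm(mode)` method is gone and `Binary.asm` is now the plain dict returned by `compile_ttir` ('ttir', 'llir', 'ptx' on the CUDA backend; SASS extraction is no longer performed here). A small illustration in the spirit of `test_noop`, assuming this Triton revision and a CUDA-capable device (not part of the diff):

    import torch
    import triton

    @triton.jit
    def _noop(**meta):
        pass

    # Launching a jitted kernel returns the cached Binary (see Kernel.__call__ below).
    x = torch.zeros(1, dtype=torch.int32, device='cuda')
    pgm = _noop[(1,)](x)
    print(sorted(pgm.asm.keys()))  # expected something like ['llir', 'ptx', 'ttir']
    print(pgm.asm['ttir'][:80])    # Triton-IR text recorded by compile_ttir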
@@ -411,9 +411,9 @@ class CodeGenerator(ast.NodeVisitor):


class Binary:
    def __init__(self, module, kernel, num_warps, num_stages, force_nc_cache, shared_mem, ir_asm):
    def __init__(self, backend, module, kernel, asm, num_warps, num_stages, force_nc_cache, shared_mem):
        # cache ir asm
        self.ir_asm = ir_asm
        self.asm = asm
        self.module = module
        self.kernel = kernel
        self.shared_mem = shared_mem
@@ -421,29 +421,13 @@ class Binary:
        self.num_stages = num_stages
        self.force_nc_cache = force_nc_cache
        self.sass = None

    def asm(self, mode):
        if mode == 'ttir':
            return self.ir_asm
        if mode == 'ptx':
            return self.module.ptx()
        if mode == 'sass':
            if self.sass is None:
                cubin = self.module.cubin()
                # get a temporary file name
                fd, path = tempfile.mkstemp(suffix='.cubin')
                f = open(path, 'wb')
                f.write(cubin)
                f.close()
                # extract SASS from cubin
                self.sass = extract(path, None)
            return self.sass
        if mode == 'llir':
            return self.module.llir()
        raise ValueError('Unsupported mode ' + mode)
        self.backend = backend

    def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1):
        stream.enqueue(self.kernel, grid_0, grid_1, grid_2, self.num_warps * 32, 1, 1, args, self.shared_mem)
        _triton.runtime.enqueue(self.backend, stream, self.kernel,
                                grid_0, grid_1, grid_2,
                                self.num_warps * 32, 1, 1,
                                args, self.shared_mem)


class CompilationError(Exception):
@@ -548,10 +532,15 @@ class Kernel:
                raise e
            raise CompilationError(self.fn.src, node, e)
        # Compile to machine code
        mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, device, num_warps, num_stages, force_nc_cache)
        if shared_mem > device.max_shared_memory():
            raise OutOfResources(shared_mem, device.max_shared_memory(), "shared memory")
        return Binary(mod, ker, num_warps, num_stages, force_nc_cache, shared_mem, ir_asm)
        if torch.version.hip is None:
            backend = _triton.runtime.backend.CUDA
        else:
            backend = _triton.runtime.backend.ROCM
        mod, ker, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages, force_nc_cache)
        max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
        if shared_mem > max_shared_memory:
            raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
        return Binary(backend, mod, ker, asm, num_warps, num_stages, force_nc_cache, shared_mem)

    def __call__(self, *wargs, grid, num_warps=4, num_stages=2, force_nc_cache=False, **meta):
        # device inference
@@ -571,19 +560,20 @@ class Kernel:
                               " Only CUDA is supported at the moment")

        device = torch.device('cuda', torch.cuda.current_device())
        tt_device = _triton.driver.cu_device(device.index, False)
        if len(set(device_ids)) != 1 or device_ids[0] != device.index:
        device_ty = device.type
        device_idx = device.index
        if len(set(device_ids)) != 1 or device_ids[0] != device_idx:
            # try to enable P2P communication
            for arg_idx, dst_idx in zip(tensor_idxs, device_ids):
                if dst_idx != device.index:
                if dst_idx != device_idx:
                    try:
                        tt_device.enable_peer_access(wargs[arg_idx].data_ptr())
                        _triton.runtime.enable_peer_access(self.backend, wargs[arg_idx].data_ptr())
                    except RuntimeError as e:
                        raise RuntimeError("Cannot enable P2P access from device {} to device {}: {}"
                                           .format(device.index, dst_idx, str(e)))
                                           .format(device_idx, dst_idx, str(e)))

        # enqueue kernel on the current device
        torch.cuda.set_device(device.index)
        torch.cuda.set_device(device_idx)
        # attributes
        args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
        attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) if isinstance(a, int)}
@@ -594,12 +584,12 @@ class Kernel:
        attr_key = frozenset(attributes.items())
        meta_key = frozenset(meta.items())
        const_key = frozenset(constants.items())
        key = (device.type, device.index, types_key, attr_key, num_warps, num_stages, meta_key, const_key)
        key = (device_ty, device_idx, types_key, attr_key, num_warps, num_stages, meta_key, const_key)
        cache = self.fn.cache
        if key not in cache:
            # compile and cache configuration if necessary
            cache[key] = self._compile(
                *wargs, device=tt_device, attributes=attributes,
                *wargs, device=device_idx, attributes=attributes,
                num_warps=num_warps, num_stages=num_stages, force_nc_cache=force_nc_cache,
                constants=constants, **meta
            )
@@ -608,8 +598,7 @@ class Kernel:
        params = struct.pack(fmt, *args)
        # enqueue cached function into stream
        binary = cache[key]
        cu_stream = torch.cuda.current_stream(device.index).cuda_stream
        stream = _triton.driver.cu_stream(cu_stream, False)
        stream = torch.cuda.current_stream(device_idx).cuda_stream
        grid = grid(meta) if hasattr(grid, '__call__') else grid
        binary(stream, params, *grid)
        return binary
@@ -64,7 +64,7 @@ def add(x: torch.Tensor, y: torch.Tensor):
    #  - each torch.tensor object is implicitly converted into a pointer to its first element.
    #  - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel.
    #  - don't forget to pass meta-parameters as keyword arguments.
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    pgm = add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
    # running asynchronously at this point.
    return output
@@ -85,6 +85,7 @@ print(
    f'The maximum difference between torch and triton is '
    f'{torch.max(torch.abs(output_torch - output_triton))}'
)
exit()

# %%
# Seems like we're good to go!