[CODEGEN] Add cache modifier to tl.load (#351)
* Add cache modifier to tl.load * Add comment to cache_modifier * Remove force_nc_cache * Update test
This commit is contained in:
@@ -203,7 +203,7 @@ std::tuple<uint64_t, uint64_t> hip_load_binary(const std::string& name, asm_map_
|
||||
// CUDA
|
||||
std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name, ir::module &ir,
|
||||
uint64_t device, int num_warps, int num_stages,
|
||||
bool force_nc_cache, asm_map_t &asm_map){
|
||||
asm_map_t &asm_map){
|
||||
llvm::LLVMContext ctx;
|
||||
// device properties
|
||||
CUdevice dev = (CUdevice)device;
|
||||
@@ -215,7 +215,7 @@ std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name,
|
||||
// Triton-IR -> NVPTX LLVM-IR
|
||||
triton::codegen::nvidia_cu_target target(cc);
|
||||
int n_shared_bytes;
|
||||
auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, force_nc_cache, n_shared_bytes);
|
||||
auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, n_shared_bytes);
|
||||
std::string tmp;
|
||||
llvm::raw_string_ostream llir(tmp);
|
||||
llir << *llvm;
|
||||
@@ -236,12 +236,12 @@ std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name,
|
||||
// HIP
|
||||
std::tuple<std::string, asm_map_t, int> hip_compile_ttir(const std::string& name, ir::module &ir,
|
||||
uint64_t device, int num_warps, int num_stages,
|
||||
bool force_nc_cache, asm_map_t &asm_map){
|
||||
asm_map_t &asm_map){
|
||||
llvm::LLVMContext ctx;
|
||||
// Triton-IR -> NVPTX LLVM-IR
|
||||
triton::codegen::amd_cl_target target;
|
||||
int n_shared_bytes;
|
||||
auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, 70, num_warps, num_stages, force_nc_cache, n_shared_bytes);
|
||||
auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, 70, num_warps, num_stages, n_shared_bytes);
|
||||
std::string tmp;
|
||||
llvm::raw_string_ostream llir(tmp);
|
||||
llir << *llvm;
|
||||
@@ -255,7 +255,7 @@ std::tuple<std::string, asm_map_t, int> hip_compile_ttir(const std::string& name
|
||||
|
||||
void init_triton_codegen(py::module &&m) {
|
||||
m.def(
|
||||
"compile_ttir", [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, bool force_nc_cache) {
|
||||
"compile_ttir", [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages) {
|
||||
std::string name = ir.get_function_list()[0]->get_name();
|
||||
// record asm as we generate
|
||||
asm_map_t asm_map;
|
||||
@@ -264,9 +264,9 @@ void init_triton_codegen(py::module &&m) {
|
||||
asm_map["ttir"] = py::cast(ttir.str());
|
||||
llvm::LLVMContext ctx;
|
||||
if(backend == CUDA)
|
||||
return cu_compile_ttir(name, ir, device, num_warps, num_stages, force_nc_cache, asm_map);
|
||||
return cu_compile_ttir(name, ir, device, num_warps, num_stages, asm_map);
|
||||
if(backend == ROCM)
|
||||
return hip_compile_ttir(name, ir, device, num_warps, num_stages, force_nc_cache, asm_map);
|
||||
return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map);
|
||||
}, py::return_value_policy::take_ownership);
|
||||
m.def("load_binary", [](backend_t backend, const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
|
||||
if(backend == CUDA)
|
||||
|
@@ -599,6 +599,30 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
|
||||
reference_out =torch.matmul(in1, in2)
|
||||
triton.testing.allclose(out, reference_out)
|
||||
|
||||
@pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
|
||||
def test_load_cache_modifier(cache):
|
||||
src = torch.empty(128, device='cuda')
|
||||
dst = torch.empty(128, device='cuda')
|
||||
|
||||
@triton.jit
|
||||
def _kernel(dst, src, **meta):
|
||||
offsets = tl.arange(0, 128)
|
||||
x = tl.load(src+offsets, cache_modifier=meta['CACHE'])
|
||||
tl.store(dst+offsets, x)
|
||||
|
||||
pgm = _kernel[(1,)](dst, src, CACHE=cache)
|
||||
ptx = pgm.asm['ptx']
|
||||
|
||||
if cache == '':
|
||||
assert 'ld.global.ca' not in ptx
|
||||
assert 'ld.global.cg' not in ptx
|
||||
if cache == '.cg':
|
||||
assert 'ld.global.cg' in ptx
|
||||
assert 'ld.global.ca' not in ptx
|
||||
if cache == '.ca':
|
||||
assert 'ld.global.ca' in ptx
|
||||
assert 'ld.global.cg' not in ptx
|
||||
|
||||
# ---------------
|
||||
# test store
|
||||
# ---------------
|
||||
|
@@ -537,7 +537,7 @@ class Kernel:
|
||||
def __init__(self, fn):
|
||||
self.fn = fn
|
||||
|
||||
def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, force_nc_cache, **meta):
|
||||
def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, **meta):
|
||||
# create IR module
|
||||
context = _triton.ir.context()
|
||||
# get just-in-time proto-type of kernel
|
||||
@@ -560,13 +560,13 @@ class Kernel:
|
||||
backend = _triton.runtime.backend.CUDA
|
||||
else:
|
||||
backend = _triton.runtime.backend.ROCM
|
||||
name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages, force_nc_cache)
|
||||
name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages)
|
||||
max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
|
||||
if shared_mem > max_shared_memory:
|
||||
raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
|
||||
return Binary(backend, name, asm, shared_mem, num_warps)
|
||||
|
||||
def __call__(self, *wargs, grid, num_warps=4, num_stages=2, force_nc_cache=False, **meta):
|
||||
def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **meta):
|
||||
# device inference
|
||||
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
|
||||
if len(tensor_idxs) == 0:
|
||||
@@ -643,7 +643,7 @@ class Kernel:
|
||||
if binary is None:
|
||||
binary = self._compile(
|
||||
*wargs, device=device_idx, attributes=attributes,
|
||||
num_warps=num_warps, num_stages=num_stages, force_nc_cache=force_nc_cache,
|
||||
num_warps=num_warps, num_stages=num_stages,
|
||||
constants=constants, **meta
|
||||
)
|
||||
if bin_cache_path:
|
||||
|
@@ -387,7 +387,7 @@ def dot(input, other, _builder=None):
|
||||
|
||||
|
||||
@builtin
|
||||
def load(pointer, mask=None, other=None, _builder=None):
|
||||
def load(pointer, mask=None, other=None, cache_modifier="", _builder=None):
|
||||
"""
|
||||
Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`.
|
||||
|
||||
@@ -401,8 +401,10 @@ def load(pointer, mask=None, other=None, _builder=None):
|
||||
:type mask: Block of triton.int1, optional
|
||||
:param other: if mask[idx] is false, return other[idx]
|
||||
:type other: Block, optional
|
||||
:param cache_modifier: changes cache option in nvidia ptx
|
||||
'type cache_modifier: str, optional
|
||||
"""
|
||||
return frontend.load(pointer, mask, other, _builder)
|
||||
return frontend.load(pointer, mask, other, cache_modifier, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
|
Reference in New Issue
Block a user