[CODEGEN] Add cache modifier to tl.load (#351)
* Add cache modifier to tl.load
* Add comment to cache_modifier
* Remove force_nc_cache
* Update test
This commit is contained in:
@@ -599,6 +599,30 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
|
||||
reference_out =torch.matmul(in1, in2)
|
||||
triton.testing.allclose(out, reference_out)
|
||||
|
||||
@pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
def test_load_cache_modifier(cache):
    """Check that the cache modifier given to ``tl.load`` shows up in the PTX.

    A small copy kernel is launched with ``CACHE=cache`` and the generated
    PTX text is searched for the ``ld.global.ca`` / ``ld.global.cg``
    instruction spellings:

    * ``""``    -- neither spelling may appear,
    * ``".ca"`` -- only ``ld.global.ca`` may appear,
    * ``".cg"`` -- only ``ld.global.cg`` may appear.
    """
    src_buf = torch.empty(128, device='cuda')
    dst_buf = torch.empty(128, device='cuda')

    @triton.jit
    def _kernel(dst, src, **meta):
        # Copy 128 elements from src to dst; the load uses the cache
        # modifier passed in through the CACHE meta-parameter.
        idx = tl.arange(0, 128)
        vals = tl.load(src+idx, cache_modifier=meta['CACHE'])
        tl.store(dst+idx, vals)

    compiled = _kernel[(1,)](dst_buf, src_buf, CACHE=cache)
    ptx_text = compiled.asm['ptx']

    # cache modifier -> (expect 'ld.global.ca', expect 'ld.global.cg')
    expectations = {
        '': (False, False),
        '.cg': (False, True),
        '.ca': (True, False),
    }
    # Any value outside the table is left unchecked.
    if cache in expectations:
        want_ca, want_cg = expectations[cache]
        assert ('ld.global.ca' in ptx_text) == want_ca
        assert ('ld.global.cg' in ptx_text) == want_cg
|
||||
|
||||
# ---------------
|
||||
# test store
|
||||
# ---------------
|
||||
|
Reference in New Issue
Block a user