[FRONTEND] Alignment fix-up (#428)

This commit is contained in:
Philippe Tillet
2022-01-11 23:11:58 -08:00
committed by GitHub
parent bbc78f6516
commit 4c94359199
4 changed files with 5386 additions and 871 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -101,12 +101,12 @@ void hip_enqueue(uint64_t stream, uint64_t kernel,
}
std::string pow2_divisor(long N){
if(N % 16 == 0) return "16";
if(N % 8 == 0) return "8";
if(N % 4 == 0) return "4";
if(N % 2 == 0) return "2";
return "1";
long pow2_divisor(long N){
if(N % 16 == 0) return 16;
if(N % 8 == 0) return 8;
if(N % 4 == 0) return 4;
if(N % 2 == 0) return 2;
return 1;
}
// Returns something like "int16", whether dtype is a torch.dtype or
@@ -127,6 +127,14 @@ std::string dtype_cache_key_part(const py::object& dtype) {
}
}
size_t get_pointer_range_size(uint64_t addr){
if(addr == 0)
return 0;
size_t size;
drv::dispatch::cuPointerGetAttribute(&size, CU_POINTER_ATTRIBUTE_RANGE_SIZE, (CUdeviceptr)addr);
return size;
}
// Launch
void parse_args(py::list& args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
std::string& cache_key, std::string& params, size_t& params_size, py::dict constants,
@@ -187,7 +195,7 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
continue;
// values divisible by small powers of 2 are specialized
cache_key += "[multipleof(";
cache_key += pow2_divisor(value);
cache_key += std::to_string(pow2_divisor(value));
cache_key += ")]";
continue;
}
@@ -213,12 +221,15 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
py::object data_ptr = arg.attr("data_ptr")();
long value = data_ptr.cast<long>();
params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
// copy param
std::memcpy(params_ptr, &value, 8);
params_ptr += 8;
// udpate cache key
cache_key += dtype_cache_key_part(arg.attr("dtype"));
cache_key += "*";
cache_key += "[multipleof(";
cache_key += pow2_divisor(value);
size_t range_size = get_pointer_range_size(value);
cache_key += std::to_string(std::min(pow2_divisor(value), pow2_divisor(range_size)));
cache_key += ")]";
continue;
}
@@ -268,6 +279,10 @@ void init_triton_runtime(py::module &&m) {
}
);
// get range size for the given pointer
m.def("get_pointer_range_size", &get_pointer_range_size);
// cache key
m.def("launch", [](py::list args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
py::object device, py::int_ stream, py::dict bin_cache, py::int_ num_warps, py::int_ num_stages,

View File

@@ -674,9 +674,17 @@ class Kernel:
def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages):
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
# attributes
args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args)
if isinstance(a, int) and i not in self.fn.do_not_specialize}
attributes = dict()
for i, arg in enumerate(wargs):
if i in self.fn.do_not_specialize:
continue
if isinstance(arg, int):
attributes[i] = Kernel.pow2_divisor(arg)
elif i in tensor_idxs:
addr = arg.data_ptr()
range_size = _triton.runtime.get_pointer_range_size(addr)
attributes[i] = min(Kernel.pow2_divisor(addr),
Kernel.pow2_divisor(range_size))
# transforms ints whose value is one into constants for just-in-time compilation
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.fn.do_not_specialize}

View File

@@ -165,6 +165,7 @@ class block:
self.numel = 1
for s in self.shape:
self.numel *= s
self.numel = constexpr(self.numel)
# Data-type wrapper
self.dtype = block._init_dtype(self.handle.type.scalar)
@@ -873,7 +874,7 @@ def ravel(x):
:param x: the input block
:type x: Block
"""
return triton.language.reshape(x, [x.type.numel])
return triton.language.reshape(x, [x.numel])
@triton.jit