[FRONTEND] Alignment fix-up (#428)
This commit is contained in:
5817
include/triton/external/CUDA/cuda.h
vendored
5817
include/triton/external/CUDA/cuda.h
vendored
File diff suppressed because it is too large
Load Diff
@@ -101,12 +101,12 @@ void hip_enqueue(uint64_t stream, uint64_t kernel,
|
||||
|
||||
}
|
||||
|
||||
std::string pow2_divisor(long N){
|
||||
if(N % 16 == 0) return "16";
|
||||
if(N % 8 == 0) return "8";
|
||||
if(N % 4 == 0) return "4";
|
||||
if(N % 2 == 0) return "2";
|
||||
return "1";
|
||||
long pow2_divisor(long N){
|
||||
if(N % 16 == 0) return 16;
|
||||
if(N % 8 == 0) return 8;
|
||||
if(N % 4 == 0) return 4;
|
||||
if(N % 2 == 0) return 2;
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Returns something like "int16", whether dtype is a torch.dtype or
|
||||
@@ -127,6 +127,14 @@ std::string dtype_cache_key_part(const py::object& dtype) {
|
||||
}
|
||||
}
|
||||
|
||||
size_t get_pointer_range_size(uint64_t addr){
|
||||
if(addr == 0)
|
||||
return 0;
|
||||
size_t size;
|
||||
drv::dispatch::cuPointerGetAttribute(&size, CU_POINTER_ATTRIBUTE_RANGE_SIZE, (CUdeviceptr)addr);
|
||||
return size;
|
||||
}
|
||||
|
||||
// Launch
|
||||
void parse_args(py::list& args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
|
||||
std::string& cache_key, std::string& params, size_t& params_size, py::dict constants,
|
||||
@@ -187,7 +195,7 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
||||
continue;
|
||||
// values divisible by small powers of 2 are specialized
|
||||
cache_key += "[multipleof(";
|
||||
cache_key += pow2_divisor(value);
|
||||
cache_key += std::to_string(pow2_divisor(value));
|
||||
cache_key += ")]";
|
||||
continue;
|
||||
}
|
||||
@@ -213,12 +221,15 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
||||
py::object data_ptr = arg.attr("data_ptr")();
|
||||
long value = data_ptr.cast<long>();
|
||||
params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
|
||||
// copy param
|
||||
std::memcpy(params_ptr, &value, 8);
|
||||
params_ptr += 8;
|
||||
// udpate cache key
|
||||
cache_key += dtype_cache_key_part(arg.attr("dtype"));
|
||||
cache_key += "*";
|
||||
cache_key += "[multipleof(";
|
||||
cache_key += pow2_divisor(value);
|
||||
size_t range_size = get_pointer_range_size(value);
|
||||
cache_key += std::to_string(std::min(pow2_divisor(value), pow2_divisor(range_size)));
|
||||
cache_key += ")]";
|
||||
continue;
|
||||
}
|
||||
@@ -268,6 +279,10 @@ void init_triton_runtime(py::module &&m) {
|
||||
}
|
||||
);
|
||||
|
||||
// get range size for the given pointer
|
||||
m.def("get_pointer_range_size", &get_pointer_range_size);
|
||||
|
||||
|
||||
// cache key
|
||||
m.def("launch", [](py::list args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
|
||||
py::object device, py::int_ stream, py::dict bin_cache, py::int_ num_warps, py::int_ num_stages,
|
||||
|
@@ -674,9 +674,17 @@ class Kernel:
|
||||
def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages):
|
||||
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
|
||||
# attributes
|
||||
args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
|
||||
attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args)
|
||||
if isinstance(a, int) and i not in self.fn.do_not_specialize}
|
||||
attributes = dict()
|
||||
for i, arg in enumerate(wargs):
|
||||
if i in self.fn.do_not_specialize:
|
||||
continue
|
||||
if isinstance(arg, int):
|
||||
attributes[i] = Kernel.pow2_divisor(arg)
|
||||
elif i in tensor_idxs:
|
||||
addr = arg.data_ptr()
|
||||
range_size = _triton.runtime.get_pointer_range_size(addr)
|
||||
attributes[i] = min(Kernel.pow2_divisor(addr),
|
||||
Kernel.pow2_divisor(range_size))
|
||||
|
||||
# transforms ints whose value is one into constants for just-in-time compilation
|
||||
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.fn.do_not_specialize}
|
||||
|
@@ -165,6 +165,7 @@ class block:
|
||||
self.numel = 1
|
||||
for s in self.shape:
|
||||
self.numel *= s
|
||||
self.numel = constexpr(self.numel)
|
||||
# Data-type wrapper
|
||||
self.dtype = block._init_dtype(self.handle.type.scalar)
|
||||
|
||||
@@ -873,7 +874,7 @@ def ravel(x):
|
||||
:param x: the input block
|
||||
:type x: Block
|
||||
"""
|
||||
return triton.language.reshape(x, [x.type.numel])
|
||||
return triton.language.reshape(x, [x.numel])
|
||||
|
||||
|
||||
@triton.jit
|
||||
|
Reference in New Issue
Block a user