diff --git a/python/src/triton.cc b/python/src/triton.cc
index 4648eee8b..da8b33227 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -607,11 +607,11 @@ void init_triton_codegen(py::module &&m) {
         // set dynamic shared memory if necessary
         std::cout << "\t" << "// set dynamic shared memory if necessary" << std::endl;
         int shared_optin;
-        drv::dispatch::hipDeviceGetAttribute(&shared_optin, hipDeviceAttributeSharedMemPerBlockOptin, device);
+        // drv::dispatch::hipDeviceGetAttribute(&shared_optin, hipDeviceAttributeSharedMemPerBlockOptin, device);
         if(n_shared_bytes > 49152 && shared_optin > 49152){
           // drv::dispatch::hipFuncSetCacheConfig(fun, hipFuncCachePreferShared);
           int shared_total, shared_static;
-          drv::dispatch::hipDeviceGetAttribute(&shared_total, hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, device);
+          // drv::dispatch::hipDeviceGetAttribute(&shared_total, hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, device);
           // drv::dispatch::hipFuncGetAttributes(&attr, fun);
           shared_total = attr.sharedSizeBytes;
           // drv::dispatch::hipFuncSetAttribute(fun, hipFuncAttributeMaxDynamicSharedMemorySize, shared_optin - shared_static);
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 0badd9007..b2d450e2f 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -104,9 +104,13 @@ def check_type_supported(dtype):
     '''
     skip test if dtype is not supported on the current device
     '''
-    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
-    if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
-        pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
+    if torch.version.hip is not None:
+        if dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16:
+            pytest.skip("bfloat16 is not supported on AMDGPU")
+    else:
+        cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+        if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
+            pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")


 @pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes])
diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index 96ef21d5f..e39eb053a 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -1024,7 +1024,7 @@ static inline hipDeviceptr_t getPointer(PyObject *obj, int idx) {{


 static PyObject* launch(PyObject* self, PyObject* args) {{
-  printf("launch(PyObject* self, PyObject* args)");
+  // printf("launch(PyObject* self, PyObject* args)");
   int gridX, gridY, gridZ;
   uint64_t _stream;
   uint64_t _function;
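
As a quick cross-check of the check_type_supported change, the skip decision can be reproduced in a small standalone sketch. This is not part of the diff: the helper name bfloat16_skip_reason is hypothetical, and torch.cuda.get_device_capability stands in for the Triton-internal _triton.runtime.cc call used in the actual test; the only behavior relied on is that torch.version.hip is None on CUDA builds of PyTorch and a version string on ROCm builds.

    from typing import Optional

    import torch


    def bfloat16_skip_reason() -> Optional[str]:
        """Return a pytest skip reason for bfloat16 tests, or None if they can run."""
        if torch.version.hip is not None:
            # ROCm build of PyTorch: this port has no bfloat16 path on AMD GPUs.
            return "bfloat16 is not supported on AMDGPU"
        # CUDA build: bfloat16 requires compute capability >= 8.0 (Ampere or newer).
        major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
        if major * 10 + minor < 80:
            return "bfloat16 is only supported on NVGPU with cc >= 80"
        return None

In the test itself the returned reason would be passed to pytest.skip(...), matching the two skip messages added in the diff above.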