diff --git a/include/triton/ir/type.h b/include/triton/ir/type.h index 55543ad31..2c9d25294 100644 --- a/include/triton/ir/type.h +++ b/include/triton/ir/type.h @@ -131,7 +131,6 @@ public: case FP8TyID: return "fp8"; case BF16TyID: return "bf16"; case FP16TyID: return "f16"; - case BF16TyID: return "bf16"; case FP32TyID: return "f32"; case FP64TyID: return "f64"; case LabelTyID: return "label"; diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 767c05130..7a72df924 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -3320,13 +3320,16 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) { default: throw std::runtime_error("unreachable"); } ir::value *arg = x->get_operand(0); - if(arg->get_type()->get_tile_rank() == 1) - visit_reduce1d_inst(x, do_acc, neutral); - else - bool is_coalesced_scanline = layouts_->is_coalesced_scanline(x); - bool is_a100_mma = layouts_->is_a100_mma(x); - if (is_coalesced_scanline || is_a100_mma) +#ifdef USE_ROCM + visit_reducend_inst(x, do_acc, neutral); +#else + bool is_coalesced_scanline = layouts_->is_coalesced_scanline(x); + bool is_a100_mma = layouts_->is_a100_mma(x); + if (is_coalesced_scanline || is_a100_mma) visit_reducend_inst_fast(x, do_acc, neutral); + else + visit_reducend_inst(x, do_acc, neutral); +#endif } /** diff --git a/python/setup.py b/python/setup.py index 817e2e091..e8d549a04 100644 --- a/python/setup.py +++ b/python/setup.py @@ -13,6 +13,7 @@ from typing import NamedTuple from setuptools import Extension, setup from setuptools.command.build_ext import build_ext +import torch # Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py @@ -32,7 +33,8 @@ def get_build_type(): def use_system_llvm(): if platform.system() == "Windows": return True - versions = ['-11.0', '-11', '-11-64'] + # versions = ['-11.0', '-11', '-11-64'] + versions = ['-13.0', '-13', '-13-64'] supported = ['llvm-config{v}'.format(v=v) for v in versions] paths = [distutils.spawn.find_executable(cfg) for cfg in supported] return any(p is not None for p in paths) @@ -53,7 +55,7 @@ def get_thirdparty_packages(triton_cache_path): if not use_system_llvm(): # donwload LLVM if no suitable system LLVM is installed packages.append( - Package("llvm", "clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04", "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04.tar.xz", "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR") + Package("llvm", "clang+llvm-13.0.0-x86_64-linux-gnu-ubuntu-16.04", "https://github.com/llvm/llvm-project/releases/download/llvmorg-13.0.0/clang+llvm-13.0.0-x86_64-linux-gnu-ubuntu-16.04.tar.xz", "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR") ) thirdparty_cmake_args = []