add fixes
This commit is contained in:
@@ -131,7 +131,6 @@ public:
|
|||||||
case FP8TyID: return "fp8";
|
case FP8TyID: return "fp8";
|
||||||
case BF16TyID: return "bf16";
|
case BF16TyID: return "bf16";
|
||||||
case FP16TyID: return "f16";
|
case FP16TyID: return "f16";
|
||||||
case BF16TyID: return "bf16";
|
|
||||||
case FP32TyID: return "f32";
|
case FP32TyID: return "f32";
|
||||||
case FP64TyID: return "f64";
|
case FP64TyID: return "f64";
|
||||||
case LabelTyID: return "label";
|
case LabelTyID: return "label";
|
||||||
|
@@ -3320,13 +3320,16 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
|
|||||||
default: throw std::runtime_error("unreachable");
|
default: throw std::runtime_error("unreachable");
|
||||||
}
|
}
|
||||||
ir::value *arg = x->get_operand(0);
|
ir::value *arg = x->get_operand(0);
|
||||||
if(arg->get_type()->get_tile_rank() == 1)
|
#ifdef USE_ROCM
|
||||||
visit_reduce1d_inst(x, do_acc, neutral);
|
visit_reducend_inst(x, do_acc, neutral);
|
||||||
else
|
#else
|
||||||
bool is_coalesced_scanline = layouts_->is_coalesced_scanline(x);
|
bool is_coalesced_scanline = layouts_->is_coalesced_scanline(x);
|
||||||
bool is_a100_mma = layouts_->is_a100_mma(x);
|
bool is_a100_mma = layouts_->is_a100_mma(x);
|
||||||
if (is_coalesced_scanline || is_a100_mma)
|
if (is_coalesced_scanline || is_a100_mma)
|
||||||
visit_reducend_inst_fast(x, do_acc, neutral);
|
visit_reducend_inst_fast(x, do_acc, neutral);
|
||||||
|
else
|
||||||
|
visit_reducend_inst(x, do_acc, neutral);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@@ -13,6 +13,7 @@ from typing import NamedTuple
|
|||||||
|
|
||||||
from setuptools import Extension, setup
|
from setuptools import Extension, setup
|
||||||
from setuptools.command.build_ext import build_ext
|
from setuptools.command.build_ext import build_ext
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
# Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py
|
# Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py
|
||||||
@@ -32,7 +33,8 @@ def get_build_type():
|
|||||||
def use_system_llvm():
|
def use_system_llvm():
|
||||||
if platform.system() == "Windows":
|
if platform.system() == "Windows":
|
||||||
return True
|
return True
|
||||||
versions = ['-11.0', '-11', '-11-64']
|
# versions = ['-11.0', '-11', '-11-64']
|
||||||
|
versions = ['-13.0', '-13', '-13-64']
|
||||||
supported = ['llvm-config{v}'.format(v=v) for v in versions]
|
supported = ['llvm-config{v}'.format(v=v) for v in versions]
|
||||||
paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
|
paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
|
||||||
return any(p is not None for p in paths)
|
return any(p is not None for p in paths)
|
||||||
@@ -53,7 +55,7 @@ def get_thirdparty_packages(triton_cache_path):
|
|||||||
if not use_system_llvm():
|
if not use_system_llvm():
|
||||||
# donwload LLVM if no suitable system LLVM is installed
|
# donwload LLVM if no suitable system LLVM is installed
|
||||||
packages.append(
|
packages.append(
|
||||||
Package("llvm", "clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04", "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04.tar.xz", "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR")
|
Package("llvm", "clang+llvm-13.0.0-x86_64-linux-gnu-ubuntu-16.04", "https://github.com/llvm/llvm-project/releases/download/llvmorg-13.0.0/clang+llvm-13.0.0-x86_64-linux-gnu-ubuntu-16.04.tar.xz", "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR")
|
||||||
)
|
)
|
||||||
|
|
||||||
thirdparty_cmake_args = []
|
thirdparty_cmake_args = []
|
||||||
|
Reference in New Issue
Block a user