diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml new file mode 100644 index 000000000..d627888c5 --- /dev/null +++ b/.github/workflows/wheels.yml @@ -0,0 +1,40 @@ +name: Wheels +on: + workflow_dispatch: + schedule: + - cron: "0 0 * * *" + +jobs: + + Build-Wheels: + + runs-on: [self-hosted, V100] + + steps: + + - name: Checkout + uses: actions/checkout@v2 + + - name: Patch setup.py + run: | + #sed -i 's/name\=\"triton\"/name="triton-nightly"/g' python/setup.py + export LATEST_DATE=$(TZ=UTC0 git show --quiet --date='format-local:%Y%m%d' --format="%cd") + sed -i -r "s/version\=\"(.*)\"/version=\"\1-dev"$LATEST_DATE"\"/g" python/setup.py + echo "" >> python/setup.cfg + echo "[build_ext]" >> python/setup.cfg + echo "base-dir=/project" >> python/setup.cfg + + - name: Build wheels + run: | + export CIBW_MANYLINUX_X86_64_IMAGE="manylinux2014" + export CIBW_MANYLINUX_PYPY_X86_64_IMAGE="manylinux2014" + export CIBW_BEFORE_BUILD="pip install cmake;\ + yum install -y llvm11 llvm11-devel llvm11-static llvm11-libs zlib-devel;" + export CIBW_SKIP="{cp,pp}35-*" + export CIBW_BUILD="{cp,pp}3*-manylinux_x86_64" + python3 -m cibuildwheel python --output-dir wheelhouse + + + - name: Upload wheels to PyPI + run: | + python3 -m twine upload wheelhouse/* --skip-existing \ No newline at end of file diff --git a/LICENSE b/LICENSE index a50165bb2..4d18c2bdf 100755 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ /* * Copyright 2018-2020 Philippe Tillet -* Copyright 2020-2021 OpenAI +* Copyright 2020-2022 OpenAI * * Permission is hereby granted, free of charge, to any person obtaining * a copy of this software and associated documentation files @@ -20,4 +20,4 @@ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -*/ \ No newline at end of file +*/ diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 402c6cb86..c958fa7f1 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -132,6 +132,9 @@ class dtype: def is_int_signed(self): return self.name in dtype.SINT_TYPES + def is_int_unsigned(self): + return self.name in dtype.UINT_TYPES + def is_int(self): return self.name in dtype.SINT_TYPES + dtype.UINT_TYPES @@ -460,6 +463,11 @@ class tensor: other = _to_tensor(other, _builder) return semantic.floordiv(self, other, _builder) + @builtin + def __rfloordiv__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.floordiv(other, self, _builder) + @builtin def __mod__(self, other, _builder=None): other = _to_tensor(other, _builder) @@ -1041,21 +1049,35 @@ def debug_barrier(_builder=None): @builtin -def multiple_of(input, value, _builder=None): +def multiple_of(input, values, _builder=None): """ Let the compiler knows that the values in :code:`input` are all multiples of :code:`value`. 
""" - value = _constexpr_to_value(value) - return semantic.multiple_of(input, value) + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return semantic.multiple_of(input, values) @builtin -def max_contiguous(input, value, _builder=None): +def max_contiguous(input, values, _builder=None): """ Let the compiler knows that the `value` first values in :code:`input` are contiguous. """ - value = _constexpr_to_value(value) - return semantic.max_contiguous(input, value) + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return semantic.max_contiguous(input, values) # ----------------------- diff --git a/python/triton/language/random.py b/python/triton/language/random.py index c95eac9fc..67de92c43 100644 --- a/python/triton/language/random.py +++ b/python/triton/language/random.py @@ -91,9 +91,10 @@ def uint32_to_uniform_float(x): Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1). """ x = x.to(tl.int32, bitcast=True) - max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647. + # maximum value such that `MAX_INT * scale < 1.0` (with float rounding) + scale = 4.6566127342e-10 x = tl.where(x < 0, -x - 1, x) - return x * max + return x * scale @triton.jit diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index aaa19d7f9..619fad4ee 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -58,14 +58,22 @@ def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> t return tl.float32 # 3 ) if one operand is half, the other is implicitly converted to half # unless we're doing / or %, which do not exist natively in PTX for fp16. 
+ # Supported PTX op: add, sub, mul, fma, neg, abs, min, max, tanh, ex2, setp if a_ty.is_fp16() or b_ty.is_fp16(): if div_or_mod: return tl.float32 else: return tl.float16 + # 4) return bf16 only if both operands are bf16 + if a_ty.is_bf16() or b_ty.is_bf16(): + if div_or_mod: + return tl.float32 + if a_ty.is_bf16() and b_ty.is_bf16(): + return tl.bfloat16 + return tl.float32 if not a_ty.is_int() or not b_ty.is_int(): assert False - # 4 ) both operands are integer and undergo + # 5 ) both operands are integer and undergo # integer promotion if div_or_mod and a_ty.int_signedness != b_ty.int_signedness: raise ValueError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + " because they have different signedness;" @@ -452,6 +460,9 @@ def not_equal(input: tl.tensor, def arange(start: int, end: int, builder: ir.builder) -> tl.tensor: + if not isinstance(start, int) or not isinstance(end, int): + raise ValueError("arange's arguments must be of type tl.constexpr") + shape = [end - start] ret_ty = tl.block_type(tl.int32, shape) return tl.tensor(builder.create_make_range(start, end), ret_ty) @@ -502,6 +513,11 @@ def broadcast_impl_shape(input: tl.tensor, raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}") if shape == src_shape: return input + for i in range(len(src_shape)): + if shape[i] != src_shape[i] and src_shape[i] != 1: + raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})" + f" must match the existing size ({src_shape[i]}) at non-singleton dimension" + f" {i}: {src_shape}, {shape}") ret_ty = tl.block_type(input.type.scalar, shape) return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty) @@ -598,7 +614,13 @@ def cast(input: tl.tensor, return input src_sca_ty = src_ty.scalar dst_sca_ty = dst_ty.scalar - + # fp8 <=> bf16/fp16 + if (src_sca_ty.is_bf16() or src_sca_ty.is_fp16()) and dst_sca_ty.is_fp8(): + return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)), + dst_ty) + if src_sca_ty.is_fp8() and (dst_sca_ty.is_bf16() or dst_sca_ty.is_fp16()): + return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)), + dst_ty) # bf16 <=> (not fp32) if (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()) or \ (dst_sca_ty.is_bf16() and not src_sca_ty.is_fp32()): @@ -783,16 +805,25 @@ def atomic_cas(ptr: tl.tensor, cmp: tl.tensor, val: tl.tensor, builder: ir.builder) -> tl.tensor: - # TODO: type checking + element_ty = ptr.type.scalar.element_ty + if element_ty.primitive_bitwidth not in [16, 32, 64]: + raise ValueError("atomic_cas only supports elements with width {16, 32, 64}") return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle), val.type) def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, + op: str, builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]: if not ptr.type.scalar.is_ptr(): raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__()) + + element_ty = ptr.type.scalar.element_ty + if element_ty is tl.float16 and op != 'add': + raise ValueError("atomic_" + op + " does not support fp16") + if element_ty in [tl.int1, tl.int8, tl.int16, tl.bfloat16]: + raise ValueError("atomic_" + op + " does not support " + str(element_ty)) if ptr.type.is_block(): if mask: mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) @@ -813,7 +844,7 @@ def atomic_max(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, builder: ir.builder) -> tl.tensor: - ptr, val, mask =
atom_red_typechecking_impl(ptr, val, mask, builder) + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder) sca_ty = val.type.scalar # direct call to atomic_max for integers if sca_ty.is_int(): @@ -845,7 +876,7 @@ def atomic_min(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder) sca_ty = val.type.scalar # direct call to atomic_min for integers if sca_ty.is_int(): @@ -885,7 +916,7 @@ def atomic_add(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'add', builder) sca_ty = val.type.scalar op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle), val.type) @@ -895,7 +926,7 @@ def atomic_and(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'and', builder) return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle), val.type) @@ -903,7 +934,7 @@ def atomic_or(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'or', builder) return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle), val.type) @@ -911,7 +942,7 @@ def atomic_xor(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xor', builder) return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle), val.type) @@ -919,7 +950,7 @@ def atomic_xchg(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder) return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle), val.type) # ===----------------------------------------------------------------------===// @@ -961,16 +992,8 @@ def where(condition: tl.tensor, x = broadcast_impl_shape(x, condition.type.get_block_shapes(), builder) y = broadcast_impl_shape(y, condition.type.get_block_shapes(), builder) - # TODO: we need to check x's and y's shape? 
- x_ty = x.type.scalar - y_ty = y.type.scalar - ty = computation_type_impl(x_ty, y_ty, div_or_mod=False) - x = cast(x, ty, builder) - y = cast(y, ty, builder) - if x.type.is_block(): - ret_ty = tl.block_type(ty, x.type.shape) - else: - ret_ty = ty + x, y = binary_op_type_checking_impl(x, y, builder, True, True) + ret_ty = x.type return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty) @@ -987,6 +1010,21 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str, if scalar_ty.is_int() and scalar_ty.int_bitwidth <= 32: input = cast(input, tl.int32, builder) + # hardware doesn't support FMAX, FMIN, CMP for bfloat16 + if scalar_ty is tl.bfloat16: + input = cast(input, tl.float32, builder) + + # choose the right unsigned operation + if scalar_ty.is_int_unsigned(): + int_op_to_unit = { + ir.REDUCE_OP.MIN: ir.REDUCE_OP.UMIN, + ir.REDUCE_OP.MAX: ir.REDUCE_OP.UMAX, + ir.REDUCE_OP.ARGMIN: ir.REDUCE_OP.ARGUMIN, + ir.REDUCE_OP.ARGMAX: ir.REDUCE_OP.ARGUMAX, + } + if INT_OP in int_op_to_unit: + INT_OP = int_op_to_unit[INT_OP] + # get result type shape = input.type.shape ret_shape = [] @@ -1056,13 +1094,17 @@ def sqrt(x: tl.tensor, builder: ir.builder) -> tl.tensor: ## -def multiple_of(x: tl.tensor, value: int) -> tl.tensor: - x.handle.multiple_of(value) +def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor: + if len(x.shape) != len(values): + raise ValueError("Shape of input to multiple_of does not match the length of values") + x.handle.multiple_of(values) return x -def max_contiguous(x: tl.tensor, value: int) -> tl.tensor: - x.handle.max_contiguous(value) +def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor: + if len(x.shape) != len(values): + raise ValueError("Shape of input to max_contiguous does not match the length of values") + x.handle.max_contiguous(values) return x diff --git a/python/triton/ops/cross_entropy.py b/python/triton/ops/cross_entropy.py index 910417d2c..63ce81074 100644 --- a/python/triton/ops/cross_entropy.py +++ b/python/triton/ops/cross_entropy.py @@ -65,7 +65,7 @@ def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr): # write result in-place in PROBS dout = tl.load(DPROBS + row) din = (probs - delta) * dout - tl.store(PROBS, din.to(tl.float16), mask=cols < N) + tl.store(PROBS, din.to(PROBS.dtype.element_ty), mask=cols < N) class _cross_entropy(torch.autograd.Function): diff --git a/python/triton/ops/matmul_perf_model.py b/python/triton/ops/matmul_perf_model.py index 9c10b88d8..004f236b9 100644 --- a/python/triton/ops/matmul_perf_model.py +++ b/python/triton/ops/matmul_perf_model.py @@ -26,7 +26,7 @@ def get_simd_tflops(backend, device, num_ctas, num_warps, dtype): def get_tflops(backend, device, num_ctas, num_warps, dtype): cc = _triton.runtime.cc(backend, device) if cc < 80 and dtype == torch.float32: - return get_simd_tflops() + return get_simd_tflops(backend, device, num_ctas, num_warps, dtype) return get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype) diff --git a/python/triton/testing.py b/python/triton/testing.py index f42f38b9f..2c9ece2fe 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -1,6 +1,8 @@ +import functools import os import subprocess import sys +from contextlib import contextmanager import torch @@ -358,6 +360,80 @@ def get_max_tensorcore_tflops(dtype: torch.dtype, backend=None, device=None, clo tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 return tflops +# create decorator that wraps test function into +# a cuda-memcheck system call 
+ + +def cuda_memcheck(**target_kwargs): + def decorator(test_fn): + @functools.wraps(test_fn) + def wrapper(*args, **kwargs): + import psutil + ppid_name = psutil.Process(os.getppid()).name() + run_cuda_memcheck = target_kwargs.items() <= kwargs.items() + if run_cuda_memcheck and ppid_name != "cuda-memcheck": + # get path of the file that defines the test + path = os.path.realpath(test_fn.__globals__["__file__"]) + env = {"PATH": os.environ["PATH"], "PYTORCH_NO_CUDA_MEMORY_CACHING": "1"} + assert 'request' in kwargs, "memcheck'ed test must have a (possibly unused) `request` fixture" + test_id = kwargs['request'].node.callspec.id + cmd = f"{path}::{test_fn.__name__}[{test_id}]" + out = subprocess.run(["cuda-memcheck", "pytest", "-vs", cmd], capture_output=True, env=env) + assert out.returncode == 0, "cuda-memcheck returned an error: bounds checking failed" + assert "ERROR SUMMARY: 0 errors" in str(out.stdout) + else: + test_fn(*args, **kwargs) + return wrapper + return decorator + + +def nvsmi_attr(attrs): + attrs = ",".join(attrs) + cmd = [ + "nvidia-smi", + "-i", + "0", + "--query-gpu=" + attrs, + "--format=csv,noheader,nounits", + ] + out = subprocess.check_output(cmd) + ret = out.decode(sys.stdout.encoding).split(",") + ret = [int(x) for x in ret] + return ret + + +@contextmanager +def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215): + try: + subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"]) + subprocess.check_output( + [ + "nvidia-smi", + "-i", + "0", + f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}", + ] + ) + subprocess.check_output( + [ + "nvidia-smi", + "-i", + "0", + f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}", + ] + ) + cur_sm_clock = nvsmi_attr(["clocks.current.sm"])[0] + cur_mem_clock = nvsmi_attr(["clocks.current.memory"])[0] + assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz" + assert abs(cur_mem_clock - ref_mem_clock) < 10, f"GPU memory must run at {ref_mem_clock} MHz" + tflops = 1e-6 * 2 * 108 * 4 * 256 * ref_sm_clock + gbps = 640 * 2 * ref_mem_clock * 1e-3 + yield tflops, gbps + finally: + subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "0"]) + subprocess.check_output(["nvidia-smi", "-i", "0", "-rgc"]) + subprocess.check_output(["nvidia-smi", "-i", "0", "-rmc"]) + def get_max_simd_tflops(dtype: torch.dtype, backend=None, device=None): if not backend: diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index e323c1d21..f685a6059 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -76,8 +76,8 @@ You will specifically learn about: # # .. code-block:: python # -# pa += BLOCK_SIZE_K * stride_ak; -# pb += BLOCK_SIZE_K * stride_bk; +# a_ptrs += BLOCK_SIZE_K * stride_ak; +# b_ptrs += BLOCK_SIZE_K * stride_bk; # # # L2 Cache Optimizations
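With the change to `tl.multiple_of` and `tl.max_contiguous` above, both hints now accept either a single constexpr (as before) or a list with one entry per dimension of the input; `semantic.py` additionally checks that the length of the list matches the rank of the tensor. A minimal sketch of the new list form, assuming the usual pairing of the two hints on a 2D block of pointers; the kernel and names below are illustrative, not taken from this patch:

import triton
import triton.language as tl


@triton.jit
def copy_tile(in_ptr, out_ptr, stride_m, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    ptrs = in_ptr + offs_m[:, None] * stride_m + offs_n[None, :]
    # Hint that, along the last dimension, the accessed addresses are assumed
    # aligned and contiguous in groups of 16 so the compiler can vectorize the
    # access; the leading dimension gets the trivial hint of 1.
    ptrs = tl.multiple_of(ptrs, [1, 16])
    ptrs = tl.max_contiguous(ptrs, [1, 16])
    tile = tl.load(ptrs)
    tl.store(out_ptr + offs_m[:, None] * stride_m + offs_n[None, :], tile)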
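The new scale constant in `uint32_to_uniform_float` exists because the old `1/MAX_INT` value, once rounded to float32, maps the largest possible input to exactly 1.0, i.e. outside the documented [0, 1) range. A small NumPy check of the rounding behaviour the new constant guards against (assuming float32 arithmetic, matching what runs on the GPU):

import numpy as np

max_int = np.float32(2147483647)           # rounds to 2**31 as a float32
old_scale = np.float32(4.656613e-10)       # rounds to exactly 2**-31
new_scale = np.float32(4.6566127342e-10)   # a few ulps below 2**-31

print(max_int * old_scale)   # 1.0        -> falls outside [0, 1)
print(max_int * new_scale)   # 0.9999997  -> stays strictly below 1.0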
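`atom_red_typechecking_impl` now receives the operation name and rejects element types the hardware cannot handle: `float16` is only accepted for `atomic_add`, and `int1`/`int8`/`int16`/`bfloat16` are rejected for every atomic RMW. A hypothetical kernel that stays within the allowed combinations, accumulating fp16 values with `tl.atomic_add`:

import triton
import triton.language as tl


@triton.jit
def scatter_add(out_ptr, idx_ptr, val_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    idx = tl.load(idx_ptr + offs, mask=mask)
    val = tl.load(val_ptr + offs, mask=mask)
    # out_ptr is assumed to point to fp16 (or fp32/int32) data; atomic_add is
    # the only atomic that still accepts fp16 after this change, while e.g.
    # tl.atomic_max on an fp16 or bf16 pointer now raises a ValueError.
    tl.atomic_add(out_ptr + idx, val, mask=mask)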
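The `cuda_memcheck` decorator added to `triton.testing` re-runs a test inside a `cuda-memcheck pytest ...` subprocess whenever the decorator's keyword arguments are a subset of the test's parameters, and asserts that the report ends with zero errors. A hedged usage sketch — the test body and the `oob_check` parameter are hypothetical, but the `request` fixture is required by the implementation above to rebuild the pytest node id:

import pytest
import torch
import triton.testing


@triton.testing.cuda_memcheck(oob_check=True)
@pytest.mark.parametrize("oob_check", [False, True])
def test_masked_store(oob_check, request):
    # Only the oob_check=True cases are re-executed under cuda-memcheck;
    # the other cases run in-process as usual.
    x = torch.zeros(128, device="cuda")
    ...  # launch a Triton kernel with/without proper masking and check `x`

Note that the decorator imports psutil at call time, so the test environment needs it installed.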
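`set_gpu_clock` locks the SM and memory clocks through `nvidia-smi` (which requires root), verifies via `nvsmi_attr` that the locks took effect, and yields reference peak TFLOPS and GB/s figures computed from hard-coded A100 characteristics. An illustrative way to use it for a stable bandwidth measurement; the device, tensor size, and elementwise op here are assumptions, not part of the patch:

import torch
from triton.testing import set_gpu_clock

x = torch.randn(1 << 26, device="cuda", dtype=torch.float16)
with set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215) as (ref_tflops, ref_gbps):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    y = x * 2.0  # warm-up
    start.record()
    y = x * 2.0
    end.record()
    torch.cuda.synchronize()
    ms = start.elapsed_time(end)
    achieved = 2 * x.numel() * x.element_size() * 1e-6 / ms  # GB/s: one read + one write
    print(f"{achieved:.0f} GB/s achieved vs ~{ref_gbps:.0f} GB/s reference")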