Merge select commits from master branch into triton-mlir (#799)

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: vesuppi <zt9465@gmail.com>
Co-authored-by: Jason Ansel <jansel@jansel.net>
Co-authored-by: daadaada <dyanab@connect.ust.hk>
Co-authored-by: Anton Kostin <masguit42@users.noreply.github.com>
Co-authored-by: Yunxing Dai <nov503@gmail.com>
Co-authored-by: Shintaro Iwasaki <shintaro.iwasaki.work@gmail.com>
This commit is contained in:
Philippe Tillet
2022-10-24 14:52:37 -07:00
committed by GitHub
parent 877844de4f
commit fcb228d1d4
9 changed files with 219 additions and 38 deletions

40
.github/workflows/wheels.yml vendored Normal file
View File

@@ -0,0 +1,40 @@
name: Wheels
on:
workflow_dispatch:
schedule:
- cron: "0 0 * * *"
jobs:
Build-Wheels:
runs-on: [self-hosted, V100]
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Patch setup.py
run: |
#sed -i 's/name\=\"triton\"/name="triton-nightly"/g' python/setup.py
export LATEST_DATE=$(TZ=UTC0 git show --quiet --date='format-local:%Y%m%d' --format="%cd")
sed -i -r "s/version\=\"(.*)\"/version=\"\1-dev"$LATEST_DATE"\"/g" python/setup.py
echo "" >> python/setup.cfg
echo "[build_ext]" >> python/setup.cfg
echo "base-dir=/project" >> python/setup.cfg
- name: Build wheels
run: |
export CIBW_MANYLINUX_X86_64_IMAGE="manylinux2014"
export CIBW_MANYLINUX_PYPY_X86_64_IMAGE="manylinux2014"
export CIBW_BEFORE_BUILD="pip install cmake;\
yum install -y llvm11 llvm11-devel llvm11-static llvm11-libs zlib-devel;"
export CIBW_SKIP="{cp,pp}35-*"
export CIBW_BUILD="{cp,pp}3*-manylinux_x86_64"
python3 -m cibuildwheel python --output-dir wheelhouse
- name: Upload wheels to PyPI
run: |
python3 -m twine upload wheelhouse/* --skip-existing

View File

@@ -1,6 +1,6 @@
/*
* Copyright 2018-2020 Philippe Tillet
* Copyright 2020-2021 OpenAI
* Copyright 2020-2022 OpenAI
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
@@ -20,4 +20,4 @@
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
*/

View File

@@ -132,6 +132,9 @@ class dtype:
def is_int_signed(self):
return self.name in dtype.SINT_TYPES
def is_int_unsigned(self):
return self.name in dtype.UINT_TYPES
def is_int(self):
return self.name in dtype.SINT_TYPES + dtype.UINT_TYPES
@@ -460,6 +463,11 @@ class tensor:
other = _to_tensor(other, _builder)
return semantic.floordiv(self, other, _builder)
@builtin
def __rfloordiv__(self, other, _builder=None):
other = _to_tensor(other, _builder)
return semantic.floordiv(other, self, _builder)
@builtin
def __mod__(self, other, _builder=None):
other = _to_tensor(other, _builder)
@@ -1041,21 +1049,35 @@ def debug_barrier(_builder=None):
@builtin
def multiple_of(input, value, _builder=None):
def multiple_of(input, values, _builder=None):
"""
Let the compiler knows that the values in :code:`input` are all multiples of :code:`value`.
"""
value = _constexpr_to_value(value)
return semantic.multiple_of(input, value)
if isinstance(values, constexpr):
values = [values]
for i, d in enumerate(values):
if not isinstance(d, constexpr):
raise TypeError(f"values element {i} must have type `constexpr`")
if not isinstance(d.value, int):
raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
values = [x.value for x in values]
return semantic.multiple_of(input, values)
@builtin
def max_contiguous(input, value, _builder=None):
def max_contiguous(input, values, _builder=None):
"""
Let the compiler knows that the `value` first values in :code:`input` are contiguous.
"""
value = _constexpr_to_value(value)
return semantic.max_contiguous(input, value)
if isinstance(values, constexpr):
values = [values]
for i, d in enumerate(values):
if not isinstance(d, constexpr):
raise TypeError(f"values element {i} must have type `constexpr`")
if not isinstance(d.value, int):
raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
values = [x.value for x in values]
return semantic.max_contiguous(input, values)
# -----------------------

View File

@@ -91,9 +91,10 @@ def uint32_to_uniform_float(x):
Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
"""
x = x.to(tl.int32, bitcast=True)
max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647.
# maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
scale = 4.6566127342e-10
x = tl.where(x < 0, -x - 1, x)
return x * max
return x * scale
@triton.jit

View File

@@ -58,14 +58,22 @@ def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> t
return tl.float32
# 3 ) if one operand is half, the other is implicitly converted to half
# unless we're doing / or %, which do not exist natively in PTX for fp16.
# Supported PTX op: add, sub, mul, fma, neg, abs, min, max, tanh, ex2, setp
if a_ty.is_fp16() or b_ty.is_fp16():
if div_or_mod:
return tl.float32
else:
return tl.float16
# 4) return bf16 only if both operands are of bf16
if a_ty.is_bf16() or b_ty.is_bf16():
if div_or_mod:
return tl.float32
if a_ty.is_bf16() and b_ty.is_bf16():
return tl.bfloat16
return tl.float32
if not a_ty.is_int() or not b_ty.is_int():
assert False
# 4 ) both operands are integer and undergo
# 5 ) both operands are integer and undergo
# integer promotion
if div_or_mod and a_ty.int_signedness != b_ty.int_signedness:
raise ValueError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + " because they have different signedness;"
@@ -452,6 +460,9 @@ def not_equal(input: tl.tensor,
def arange(start: int, end: int, builder: ir.builder) -> tl.tensor:
if not isinstance(start, int) or not isinstance(end, int):
raise ValueError("arange's arguments must be of type tl.constexpr")
shape = [end - start]
ret_ty = tl.block_type(tl.int32, shape)
return tl.tensor(builder.create_make_range(start, end), ret_ty)
@@ -502,6 +513,11 @@ def broadcast_impl_shape(input: tl.tensor,
raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}")
if shape == src_shape:
return input
for i in range(len(src_shape)):
if shape[i] != src_shape[i] and src_shape[i] != 1:
raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})"
f" must match the existing size ({src_shape[1]}) at non-singleton dimension"
f" {i}: {src_shape}, {shape}")
ret_ty = tl.block_type(input.type.scalar, shape)
return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty)
@@ -598,7 +614,13 @@ def cast(input: tl.tensor,
return input
src_sca_ty = src_ty.scalar
dst_sca_ty = dst_ty.scalar
# fp8 <=> bf16/fp16
if (src_sca_ty.is_bf16() or src_sca_ty.is_fp16()) and dst_sca_ty.is_fp8():
return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)),
dst_ty)
if src_sca_ty.is_fp8() and (dst_sca_ty.is_bf16() or dst_sca_ty.is_fp16()):
return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)),
dst_ty)
# bf16 <=> (not fp32)
if (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()) or \
(dst_sca_ty.is_bf16() and not src_sca_ty.is_fp32()):
@@ -783,16 +805,25 @@ def atomic_cas(ptr: tl.tensor,
cmp: tl.tensor,
val: tl.tensor,
builder: ir.builder) -> tl.tensor:
# TODO: type checking
element_ty = ptr.type.scalar.element_ty
if element_ty.primitive_bitwidth not in [16, 32, 64]:
raise ValueError("atomic_cas only supports elements with width {16, 32, 64}")
return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle), val.type)
def atom_red_typechecking_impl(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
op: str,
builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]:
if not ptr.type.scalar.is_ptr():
raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())
element_ty = ptr.type.scalar.element_ty
if element_ty is tl.float16 and op != 'add':
raise ValueError("atomic_" + op + " does not support fp16")
if element_ty in [tl.int1, tl.int8, tl.int16, tl.bfloat16]:
raise ValueError("atomic_" + op + " does not support " + element_ty)
if ptr.type.is_block():
if mask:
mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
@@ -813,7 +844,7 @@ def atomic_max(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder)
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder)
sca_ty = val.type.scalar
# direct call to atomic_max for integers
if sca_ty.is_int():
@@ -845,7 +876,7 @@ def atomic_min(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder)
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder)
sca_ty = val.type.scalar
# direct call to atomic_min for integers
if sca_ty.is_int():
@@ -885,7 +916,7 @@ def atomic_add(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder)
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'add', builder)
sca_ty = val.type.scalar
op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD
return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle), val.type)
@@ -895,7 +926,7 @@ def atomic_and(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder)
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'and', builder)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle), val.type)
@@ -903,7 +934,7 @@ def atomic_or(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder)
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'or', builder)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle), val.type)
@@ -911,7 +942,7 @@ def atomic_xor(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder)
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xor', builder)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle), val.type)
@@ -919,7 +950,7 @@ def atomic_xchg(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder)
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle), val.type)
# ===----------------------------------------------------------------------===//
@@ -961,16 +992,8 @@ def where(condition: tl.tensor,
x = broadcast_impl_shape(x, condition.type.get_block_shapes(), builder)
y = broadcast_impl_shape(y, condition.type.get_block_shapes(), builder)
# TODO: we need to check x's and y's shape?
x_ty = x.type.scalar
y_ty = y.type.scalar
ty = computation_type_impl(x_ty, y_ty, div_or_mod=False)
x = cast(x, ty, builder)
y = cast(y, ty, builder)
if x.type.is_block():
ret_ty = tl.block_type(ty, x.type.shape)
else:
ret_ty = ty
x, y = binary_op_type_checking_impl(x, y, builder, True, True)
ret_ty = x.type
return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty)
@@ -987,6 +1010,21 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
if scalar_ty.is_int() and scalar_ty.int_bitwidth <= 32:
input = cast(input, tl.int32, builder)
# hardware doesn't support FMAX, FMIN, CMP for bfloat16
if scalar_ty is tl.bfloat16:
input = cast(input, tl.float32, builder)
# choose the right unsigned operation
if scalar_ty.is_int_unsigned():
int_op_to_unit = {
ir.REDUCE_OP.MIN: ir.REDUCE_OP.UMIN,
ir.REDUCE_OP.MAX: ir.REDUCE_OP.UMAX,
ir.REDUCE_OP.ARGMIN: ir.REDUCE_OP.ARGUMIN,
ir.REDUCE_OP.ARGMAX: ir.REDUCE_OP.ARGUMAX,
}
if INT_OP in int_op_to_unit:
INT_OP = int_op_to_unit[INT_OP]
# get result type
shape = input.type.shape
ret_shape = []
@@ -1056,13 +1094,17 @@ def sqrt(x: tl.tensor, builder: ir.builder) -> tl.tensor:
##
def multiple_of(x: tl.tensor, value: int) -> tl.tensor:
x.handle.multiple_of(value)
def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor:
if len(x.shape) != len(values):
raise ValueError("Shape of input to multiple_of does not match the length of values")
x.handle.multiple_of(values)
return x
def max_contiguous(x: tl.tensor, value: int) -> tl.tensor:
x.handle.max_contiguous(value)
def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor:
if len(x.shape) != len(values):
raise ValueError("Shape of input to max_contiguous does not match the length of values")
x.handle.max_contiguous(values)
return x

View File

@@ -65,7 +65,7 @@ def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr):
# write result in-place in PROBS
dout = tl.load(DPROBS + row)
din = (probs - delta) * dout
tl.store(PROBS, din.to(tl.float16), mask=cols < N)
tl.store(PROBS, din.to(PROBS.dtype.element_ty), mask=cols < N)
class _cross_entropy(torch.autograd.Function):

View File

@@ -26,7 +26,7 @@ def get_simd_tflops(backend, device, num_ctas, num_warps, dtype):
def get_tflops(backend, device, num_ctas, num_warps, dtype):
cc = _triton.runtime.cc(backend, device)
if cc < 80 and dtype == torch.float32:
return get_simd_tflops()
return get_simd_tflops(backend, device, num_ctas, num_warps, dtype)
return get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype)

View File

@@ -1,6 +1,8 @@
import functools
import os
import subprocess
import sys
from contextlib import contextmanager
import torch
@@ -358,6 +360,80 @@ def get_max_tensorcore_tflops(dtype: torch.dtype, backend=None, device=None, clo
tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
return tflops
# create decorator that wraps test function into
# a cuda-memcheck system call
def cuda_memcheck(**target_kwargs):
def decorator(test_fn):
@functools.wraps(test_fn)
def wrapper(*args, **kwargs):
import psutil
ppid_name = psutil.Process(os.getppid()).name()
run_cuda_memcheck = target_kwargs.items() <= kwargs.items()
if run_cuda_memcheck and ppid_name != "cuda-memcheck":
path = os.path.realpath(test_fn.__globals__["__file__"])
# get path of current file
env = {"PATH": os.environ["PATH"], "PYTORCH_NO_CUDA_MEMORY_CACHING": "1"}
assert 'request' in kwargs, "memcheck'ed test must have a (possibly unused) `request` fixture"
test_id = kwargs['request'].node.callspec.id
cmd = f"{path}::{test_fn.__name__}[{test_id}]"
out = subprocess.run(["cuda-memcheck", "pytest", "-vs", cmd], capture_output=True, env=env)
assert out.returncode == 0, "cuda-memcheck returned an error: bounds checking failed"
assert "ERROR SUMMARY: 0 errors" in str(out.stdout)
else:
test_fn(*args, **kwargs)
return wrapper
return decorator
def nvsmi_attr(attrs):
attrs = ",".join(attrs)
cmd = [
"nvidia-smi",
"-i",
"0",
"--query-gpu=" + attrs,
"--format=csv,noheader,nounits",
]
out = subprocess.check_output(cmd)
ret = out.decode(sys.stdout.encoding).split(",")
ret = [int(x) for x in ret]
return ret
@contextmanager
def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215):
try:
subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"])
subprocess.check_output(
[
"nvidia-smi",
"-i",
"0",
f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}",
]
)
subprocess.check_output(
[
"nvidia-smi",
"-i",
"0",
f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}",
]
)
cur_sm_clock = nvsmi_attr(["clocks.current.sm"])[0]
cur_mem_clock = nvsmi_attr(["clocks.current.memory"])[0]
assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz"
assert abs(cur_mem_clock - ref_mem_clock) < 10, f"GPU SMs must run at {ref_mem_clock} MHz"
tflops = 1e-6 * 2 * 108 * 4 * 256 * ref_sm_clock
gbps = 640 * 2 * ref_mem_clock * 1e-3
yield tflops, gbps
finally:
subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "0"])
subprocess.check_output(["nvidia-smi", "-i", "0", "-rgc"])
subprocess.check_output(["nvidia-smi", "-i", "0", "-rmc"])
def get_max_simd_tflops(dtype: torch.dtype, backend=None, device=None):
if not backend:

View File

@@ -76,8 +76,8 @@ You will specifically learn about:
#
# .. code-block:: python
#
# pa += BLOCK_SIZE_K * stride_ak;
# pb += BLOCK_SIZE_K * stride_bk;
# a_ptrs += BLOCK_SIZE_K * stride_ak;
# b_ptrs += BLOCK_SIZE_K * stride_bk;
#
#
# L2 Cache Optimizations