Merge select commits from master branch into triton-mlir (#799)

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: vesuppi <zt9465@gmail.com>
Co-authored-by: Jason Ansel <jansel@jansel.net>
Co-authored-by: daadaada <dyanab@connect.ust.hk>
Co-authored-by: Anton Kostin <masguit42@users.noreply.github.com>
Co-authored-by: Yunxing Dai <nov503@gmail.com>
Co-authored-by: Shintaro Iwasaki <shintaro.iwasaki.work@gmail.com>
.github/workflows/wheels.yml (vendored, new file) | 40 lines added
@@ -0,0 +1,40 @@
+name: Wheels
+on:
+  workflow_dispatch:
+  schedule:
+    - cron: "0 0 * * *"
+
+jobs:
+
+  Build-Wheels:
+
+    runs-on: [self-hosted, V100]
+
+    steps:
+
+    - name: Checkout
+      uses: actions/checkout@v2
+
+    - name: Patch setup.py
+      run: |
+        #sed -i 's/name\=\"triton\"/name="triton-nightly"/g' python/setup.py
+        export LATEST_DATE=$(TZ=UTC0 git show --quiet --date='format-local:%Y%m%d' --format="%cd")
+        sed -i -r "s/version\=\"(.*)\"/version=\"\1-dev"$LATEST_DATE"\"/g" python/setup.py
+        echo "" >> python/setup.cfg
+        echo "[build_ext]" >> python/setup.cfg
+        echo "base-dir=/project" >> python/setup.cfg
+
+    - name: Build wheels
+      run: |
+        export CIBW_MANYLINUX_X86_64_IMAGE="manylinux2014"
+        export CIBW_MANYLINUX_PYPY_X86_64_IMAGE="manylinux2014"
+        export CIBW_BEFORE_BUILD="pip install cmake;\
+                                  yum install -y llvm11 llvm11-devel llvm11-static llvm11-libs zlib-devel;"
+        export CIBW_SKIP="{cp,pp}35-*"
+        export CIBW_BUILD="{cp,pp}3*-manylinux_x86_64"
+        python3 -m cibuildwheel python --output-dir wheelhouse
+
+
+    - name: Upload wheels to PyPI
+      run: |
+        python3 -m twine upload wheelhouse/* --skip-existing
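For context, the "Patch setup.py" step above stamps the package version with the commit date before building. A rough Python sketch of the same rewrite (the version string and date below are made up, not part of the commit):

# hypothetical sketch of the version rewrite performed by the sed command above
import os
import re
import subprocess

date = subprocess.check_output(
    ["git", "show", "--quiet", "--date=format-local:%Y%m%d", "--format=%cd"],
    env={**os.environ, "TZ": "UTC0"},
).decode().strip()                      # commit date in UTC, e.g. "20221004"

setup_py = open("python/setup.py").read()
# version="2.0.0" becomes version="2.0.0-dev20221004"
setup_py = re.sub(r'version="(.*)"', rf'version="\1-dev{date}"', setup_py)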
LICENSE | 2 lines changed
@@ -1,6 +1,6 @@
 /*
 * Copyright 2018-2020 Philippe Tillet
-* Copyright 2020-2021 OpenAI
+* Copyright 2020-2022 OpenAI
 *
 * Permission is hereby granted, free of charge, to any person obtaining
 * a copy of this software and associated documentation files
@@ -132,6 +132,9 @@ class dtype:
     def is_int_signed(self):
         return self.name in dtype.SINT_TYPES

+    def is_int_unsigned(self):
+        return self.name in dtype.UINT_TYPES
+
     def is_int(self):
         return self.name in dtype.SINT_TYPES + dtype.UINT_TYPES

@@ -460,6 +463,11 @@ class tensor:
         other = _to_tensor(other, _builder)
         return semantic.floordiv(self, other, _builder)

+    @builtin
+    def __rfloordiv__(self, other, _builder=None):
+        other = _to_tensor(other, _builder)
+        return semantic.floordiv(other, self, _builder)
+
     @builtin
     def __mod__(self, other, _builder=None):
         other = _to_tensor(other, _builder)
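The new `__rfloordiv__` makes the reflected form of floor division work inside jitted code; a minimal sketch, with a hypothetical kernel, assuming `X` holds integer data:

import triton
import triton.language as tl


@triton.jit
def rfloordiv_kernel(X, Y, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(X + offs)
    y = 1024 // x          # scalar // tensor now dispatches to tensor.__rfloordiv__
    tl.store(Y + offs, y)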
@@ -1041,21 +1049,35 @@ def debug_barrier(_builder=None):


 @builtin
-def multiple_of(input, value, _builder=None):
+def multiple_of(input, values, _builder=None):
     """
     Let the compiler knows that the values in :code:`input` are all multiples of :code:`value`.
     """
-    value = _constexpr_to_value(value)
-    return semantic.multiple_of(input, value)
+    if isinstance(values, constexpr):
+        values = [values]
+    for i, d in enumerate(values):
+        if not isinstance(d, constexpr):
+            raise TypeError(f"values element {i} must have type `constexpr`")
+        if not isinstance(d.value, int):
+            raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
+    values = [x.value for x in values]
+    return semantic.multiple_of(input, values)


 @builtin
-def max_contiguous(input, value, _builder=None):
+def max_contiguous(input, values, _builder=None):
     """
     Let the compiler knows that the `value` first values in :code:`input` are contiguous.
     """
-    value = _constexpr_to_value(value)
-    return semantic.max_contiguous(input, value)
+    if isinstance(values, constexpr):
+        values = [values]
+    for i, d in enumerate(values):
+        if not isinstance(d, constexpr):
+            raise TypeError(f"values element {i} must have type `constexpr`")
+        if not isinstance(d.value, int):
+            raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
+    values = [x.value for x in values]
+    return semantic.max_contiguous(input, values)


 # -----------------------
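With this change `tl.multiple_of` and `tl.max_contiguous` accept one hint per dimension of the input rather than a single scalar. A rough usage sketch (kernel and hint values are hypothetical):

import triton
import triton.language as tl


@triton.jit
def row_copy(A, B, stride, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    offs = tl.arange(0, BLOCK_M)[:, None] * stride + tl.arange(0, BLOCK_N)[None, :]
    ptrs = A + offs
    # per-dimension hints: rows carry no alignment guarantee, columns are 16-aligned and contiguous
    ptrs = tl.multiple_of(ptrs, [1, 16])
    ptrs = tl.max_contiguous(ptrs, [1, 16])
    tl.store(B + offs, tl.load(ptrs))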
@@ -91,9 +91,10 @@ def uint32_to_uniform_float(x):
     Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
     """
     x = x.to(tl.int32, bitcast=True)
-    max = 4.656613e-10  # = 1/MAX_INT = 1/2147483647.
+    # maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
+    scale = 4.6566127342e-10
     x = tl.where(x < 0, -x - 1, x)
-    return x * max
+    return x * scale


 @triton.jit
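A quick sanity check of the constant swap (plain double-precision arithmetic, not part of the commit): the old constant already overshoots 1.0 for the largest input, the new one stays strictly below it.

MAX_INT = 2147483647
print(MAX_INT * 4.656613e-10)      # ~1.00000003: exceeds 1.0, leaking out of [0, 1)
print(MAX_INT * 4.6566127342e-10)  # ~0.99999997: stays below 1.0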
@@ -58,14 +58,22 @@ def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> tl.dtype:
         return tl.float32
     # 3 ) if one operand is half, the other is implicitly converted to half
     #      unless we're doing / or %, which do not exist natively in PTX for fp16.
+    #      Supported PTX op: add, sub, mul, fma, neg, abs, min, max, tanh, ex2, setp
     if a_ty.is_fp16() or b_ty.is_fp16():
         if div_or_mod:
             return tl.float32
         else:
             return tl.float16
+    # 4) return bf16 only if both operands are of bf16
+    if a_ty.is_bf16() or b_ty.is_bf16():
+        if div_or_mod:
+            return tl.float32
+        if a_ty.is_bf16() and b_ty.is_bf16():
+            return tl.bfloat16
+        return tl.float32
     if not a_ty.is_int() or not b_ty.is_int():
         assert False
-    # 4 ) both operands are integer and undergo
+    # 5 ) both operands are integer and undergo
     #     integer promotion
     if div_or_mod and a_ty.int_signedness != b_ty.int_signedness:
         raise ValueError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + " because they have different signedness;"
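Read together with rule 3, the amended promotion behaves roughly as follows (illustrative cases, not an exhaustive table):

# fp16 op fp16 -> fp16    (fp32 for / and %, which have no native fp16 PTX op)
# bf16 op bf16 -> bf16    (fp32 for / and %)
# bf16 op fp16 -> fp32    (bf16 mixed with any other float type falls back to fp32)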
@@ -452,6 +460,9 @@ def not_equal(input: tl.tensor,


 def arange(start: int, end: int, builder: ir.builder) -> tl.tensor:
+    if not isinstance(start, int) or not isinstance(end, int):
+        raise ValueError("arange's arguments must be of type tl.constexpr")
+
     shape = [end - start]
     ret_ty = tl.block_type(tl.int32, shape)
     return tl.tensor(builder.create_make_range(start, end), ret_ty)
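A minimal sketch of what the new check enforces (hypothetical kernel): the bounds of `tl.arange` must be compile-time constants.

import triton
import triton.language as tl


@triton.jit
def iota(OUT, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)    # fine: 0 and BLOCK are constexpr
    tl.store(OUT + offs, offs)

# tl.arange(0, n) with `n` a plain runtime kernel argument now fails at compile time with
# "arange's arguments must be of type tl.constexpr" instead of a more cryptic error later.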
@@ -502,6 +513,11 @@ def broadcast_impl_shape(input: tl.tensor,
         raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}")
     if shape == src_shape:
         return input
+    for i in range(len(src_shape)):
+        if shape[i] != src_shape[i] and src_shape[i] != 1:
+            raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})"
+                             f" must match the existing size ({src_shape[1]}) at non-singleton dimension"
+                             f" {i}: {src_shape}, {shape}")
     ret_ty = tl.block_type(input.type.scalar, shape)
     return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty)
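Illustrative shapes for the new per-dimension check (a sketch based only on the hunk above):

# (128, 1) broadcast to (128, 64): allowed, the singleton dimension expands
# (128, 2) broadcast to (128, 64): now raises the "Cannot broadcast, the expanded size ..." error
#                                  instead of silently emitting an invalid broadcast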
@@ -598,7 +614,13 @@ def cast(input: tl.tensor,
         return input
     src_sca_ty = src_ty.scalar
     dst_sca_ty = dst_ty.scalar
+    # fp8 <=> bf16/fp16
+    if (src_sca_ty.is_bf16() or src_sca_ty.is_fp16()) and dst_sca_ty.is_fp8():
+        return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)),
+                         dst_ty)
+    if src_sca_ty.is_fp8() and (dst_sca_ty.is_bf16() or dst_sca_ty.is_fp16()):
+        return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)),
+                         dst_ty)
     # bf16 <=> (not fp32)
     if (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()) or \
        (dst_sca_ty.is_bf16() and not src_sca_ty.is_fp32()):
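A rough sketch of the conversion path this enables, assuming the fp8 dtype is exposed as `tl.float8` in this version (kernel and names are hypothetical):

import triton
import triton.language as tl


@triton.jit
def fp8_roundtrip(X, Y, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(X + offs)                   # assumed fp16 input
    x8 = x.to(tl.float8)                    # fp16 -> fp8 lowers to create_fp_trunc
    tl.store(Y + offs, x8.to(tl.float16))   # fp8 -> fp16 lowers to create_fp_ext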
@@ -783,16 +805,25 @@ def atomic_cas(ptr: tl.tensor,
               cmp: tl.tensor,
               val: tl.tensor,
               builder: ir.builder) -> tl.tensor:
-    # TODO: type checking
+    element_ty = ptr.type.scalar.element_ty
+    if element_ty.primitive_bitwidth not in [16, 32, 64]:
+        raise ValueError("atomic_cas only supports elements with width {16, 32, 64}")
     return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle), val.type)


 def atom_red_typechecking_impl(ptr: tl.tensor,
                                val: tl.tensor,
                                mask: tl.tensor,
+                               op: str,
                                builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]:
     if not ptr.type.scalar.is_ptr():
         raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())
+
+    element_ty = ptr.type.scalar.element_ty
+    if element_ty is tl.float16 and op != 'add':
+        raise ValueError("atomic_" + op + " does not support fp16")
+    if element_ty in [tl.int1, tl.int8, tl.int16, tl.bfloat16]:
+        raise ValueError("atomic_" + op + " does not support " + element_ty)
     if ptr.type.is_block():
         if mask:
             mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
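In rough terms, the checks added above accept and reject the following (illustrative summary, derived only from the hunk):

# tl.atomic_add on a float16 pointer      -> still allowed ('add' is the fp16 exception)
# tl.atomic_max on a float16 pointer      -> raises "atomic_max does not support fp16"
# any atomic RMW on int1/int8/int16/bf16  -> raises "atomic_" + op + " does not support ..."
# tl.atomic_cas on 8-bit elements         -> raises "atomic_cas only supports elements with width {16, 32, 64}"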
@@ -813,7 +844,7 @@ def atomic_max(ptr: tl.tensor,
               val: tl.tensor,
               mask: tl.tensor,
               builder: ir.builder) -> tl.tensor:
-    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder)
+    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder)
     sca_ty = val.type.scalar
     # direct call to atomic_max for integers
     if sca_ty.is_int():
@@ -845,7 +876,7 @@ def atomic_min(ptr: tl.tensor,
               val: tl.tensor,
               mask: tl.tensor,
               builder: ir.builder) -> tl.tensor:
-    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder)
+    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder)
     sca_ty = val.type.scalar
     # direct call to atomic_min for integers
     if sca_ty.is_int():
@@ -885,7 +916,7 @@ def atomic_add(ptr: tl.tensor,
               val: tl.tensor,
               mask: tl.tensor,
               builder: ir.builder) -> tl.tensor:
-    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder)
+    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'add', builder)
     sca_ty = val.type.scalar
     op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD
     return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle), val.type)
@@ -895,7 +926,7 @@ def atomic_and(ptr: tl.tensor,
               val: tl.tensor,
               mask: tl.tensor,
               builder: ir.builder) -> tl.tensor:
-    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder)
+    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'and', builder)
     return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle), val.type)


@@ -903,7 +934,7 @@ def atomic_or(ptr: tl.tensor,
              val: tl.tensor,
              mask: tl.tensor,
              builder: ir.builder) -> tl.tensor:
-    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder)
+    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'or', builder)
     return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle), val.type)


@@ -911,7 +942,7 @@ def atomic_xor(ptr: tl.tensor,
               val: tl.tensor,
               mask: tl.tensor,
               builder: ir.builder) -> tl.tensor:
-    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder)
+    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xor', builder)
     return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle), val.type)


@@ -919,7 +950,7 @@ def atomic_xchg(ptr: tl.tensor,
                val: tl.tensor,
                mask: tl.tensor,
                builder: ir.builder) -> tl.tensor:
-    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder)
+    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder)
     return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle), val.type)


 # ===----------------------------------------------------------------------===//
@@ -961,16 +992,8 @@ def where(condition: tl.tensor,
         x = broadcast_impl_shape(x, condition.type.get_block_shapes(), builder)
         y = broadcast_impl_shape(y, condition.type.get_block_shapes(), builder)

-    # TODO: we need to check x's and y's shape?
-    x_ty = x.type.scalar
-    y_ty = y.type.scalar
-    ty = computation_type_impl(x_ty, y_ty, div_or_mod=False)
-    x = cast(x, ty, builder)
-    y = cast(y, ty, builder)
-    if x.type.is_block():
-        ret_ty = tl.block_type(ty, x.type.shape)
-    else:
-        ret_ty = ty
+    x, y = binary_op_type_checking_impl(x, y, builder, True, True)
+    ret_ty = x.type
     return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty)
@@ -987,6 +1010,21 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
     if scalar_ty.is_int() and scalar_ty.int_bitwidth <= 32:
         input = cast(input, tl.int32, builder)

+    # hardware doesn't support FMAX, FMIN, CMP for bfloat16
+    if scalar_ty is tl.bfloat16:
+        input = cast(input, tl.float32, builder)
+
+    # choose the right unsigned operation
+    if scalar_ty.is_int_unsigned():
+        int_op_to_unit = {
+            ir.REDUCE_OP.MIN: ir.REDUCE_OP.UMIN,
+            ir.REDUCE_OP.MAX: ir.REDUCE_OP.UMAX,
+            ir.REDUCE_OP.ARGMIN: ir.REDUCE_OP.ARGUMIN,
+            ir.REDUCE_OP.ARGMAX: ir.REDUCE_OP.ARGUMAX,
+        }
+        if INT_OP in int_op_to_unit:
+            INT_OP = int_op_to_unit[INT_OP]
+
     # get result type
     shape = input.type.shape
     ret_shape = []
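A minimal sketch of what the unsigned mapping fixes (hypothetical kernel, assuming `X` holds uint32 data): reducing an unsigned block no longer treats values above 2**31 as negative.

import triton
import triton.language as tl


@triton.jit
def max_kernel(X, OUT, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    tl.store(OUT, tl.max(x, axis=0))    # now lowers to REDUCE_OP.UMAX for unsigned inputs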
@@ -1056,13 +1094,17 @@ def sqrt(x: tl.tensor, builder: ir.builder) -> tl.tensor:

 ##

-def multiple_of(x: tl.tensor, value: int) -> tl.tensor:
-    x.handle.multiple_of(value)
+def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor:
+    if len(x.shape) != len(values):
+        raise ValueError("Shape of input to multiple_of does not match the length of values")
+    x.handle.multiple_of(values)
     return x


-def max_contiguous(x: tl.tensor, value: int) -> tl.tensor:
-    x.handle.max_contiguous(value)
+def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor:
+    if len(x.shape) != len(values):
+        raise ValueError("Shape of input to max_contiguous does not match the length of values")
+    x.handle.max_contiguous(values)
     return x


@@ -65,7 +65,7 @@ def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr):
     # write result in-place in PROBS
     dout = tl.load(DPROBS + row)
     din = (probs - delta) * dout
-    tl.store(PROBS, din.to(tl.float16), mask=cols < N)
+    tl.store(PROBS, din.to(PROBS.dtype.element_ty), mask=cols < N)


 class _cross_entropy(torch.autograd.Function):
@@ -26,7 +26,7 @@ def get_simd_tflops(backend, device, num_ctas, num_warps, dtype):
 def get_tflops(backend, device, num_ctas, num_warps, dtype):
     cc = _triton.runtime.cc(backend, device)
     if cc < 80 and dtype == torch.float32:
-        return get_simd_tflops()
+        return get_simd_tflops(backend, device, num_ctas, num_warps, dtype)
     return get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype)


@@ -1,6 +1,8 @@
+import functools
 import os
 import subprocess
 import sys
+from contextlib import contextmanager

 import torch

@@ -358,6 +360,80 @@ def get_max_tensorcore_tflops(dtype: torch.dtype, backend=None, device=None, clo
     tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
     return tflops

+
+# create decorator that wraps test function into
+# a cuda-memcheck system call
+
+
+def cuda_memcheck(**target_kwargs):
+    def decorator(test_fn):
+        @functools.wraps(test_fn)
+        def wrapper(*args, **kwargs):
+            import psutil
+            ppid_name = psutil.Process(os.getppid()).name()
+            run_cuda_memcheck = target_kwargs.items() <= kwargs.items()
+            if run_cuda_memcheck and ppid_name != "cuda-memcheck":
+                path = os.path.realpath(test_fn.__globals__["__file__"])
+                # get path of current file
+                env = {"PATH": os.environ["PATH"], "PYTORCH_NO_CUDA_MEMORY_CACHING": "1"}
+                assert 'request' in kwargs, "memcheck'ed test must have a (possibly unused) `request` fixture"
+                test_id = kwargs['request'].node.callspec.id
+                cmd = f"{path}::{test_fn.__name__}[{test_id}]"
+                out = subprocess.run(["cuda-memcheck", "pytest", "-vs", cmd], capture_output=True, env=env)
+                assert out.returncode == 0, "cuda-memcheck returned an error: bounds checking failed"
+                assert "ERROR SUMMARY: 0 errors" in str(out.stdout)
+            else:
+                test_fn(*args, **kwargs)
+        return wrapper
+    return decorator
+
+
+def nvsmi_attr(attrs):
+    attrs = ",".join(attrs)
+    cmd = [
+        "nvidia-smi",
+        "-i",
+        "0",
+        "--query-gpu=" + attrs,
+        "--format=csv,noheader,nounits",
+    ]
+    out = subprocess.check_output(cmd)
+    ret = out.decode(sys.stdout.encoding).split(",")
+    ret = [int(x) for x in ret]
+    return ret
+
+
+@contextmanager
+def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215):
+    try:
+        subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"])
+        subprocess.check_output(
+            [
+                "nvidia-smi",
+                "-i",
+                "0",
+                f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}",
+            ]
+        )
+        subprocess.check_output(
+            [
+                "nvidia-smi",
+                "-i",
+                "0",
+                f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}",
+            ]
+        )
+        cur_sm_clock = nvsmi_attr(["clocks.current.sm"])[0]
+        cur_mem_clock = nvsmi_attr(["clocks.current.memory"])[0]
+        assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz"
+        assert abs(cur_mem_clock - ref_mem_clock) < 10, f"GPU SMs must run at {ref_mem_clock} MHz"
+        tflops = 1e-6 * 2 * 108 * 4 * 256 * ref_sm_clock
+        gbps = 640 * 2 * ref_mem_clock * 1e-3
+        yield tflops, gbps
+    finally:
+        subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "0"])
+        subprocess.check_output(["nvidia-smi", "-i", "0", "-rgc"])
+        subprocess.check_output(["nvidia-smi", "-i", "0", "-rmc"])
+
+
 def get_max_simd_tflops(dtype: torch.dtype, backend=None, device=None):
     if not backend:
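A rough usage sketch of the helpers added above (test and benchmark bodies are hypothetical):

import pytest
import triton.testing


@triton.testing.cuda_memcheck(shape=(128, 128))
@pytest.mark.parametrize("shape", [(128, 128)])
def test_store(shape, request):
    # the (possibly unused) `request` fixture lets the decorator re-run exactly
    # this parametrization under cuda-memcheck in a subprocess
    ...


def bench():
    # lock SM/memory clocks for a stable measurement; yields theoretical peak TFLOPS and GB/s
    with triton.testing.set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215) as (tflops, gbps):
        ...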
@@ -76,8 +76,8 @@ You will specifically learn about:
 #
 #  .. code-block:: python
 #
-#    pa += BLOCK_SIZE_K * stride_ak;
-#    pb += BLOCK_SIZE_K * stride_bk;
+#    a_ptrs += BLOCK_SIZE_K * stride_ak;
+#    b_ptrs += BLOCK_SIZE_K * stride_bk;
 #
 #
 # L2 Cache Optimizations