[LANG] Various (relatively minor) improvements (#320)
This commit is contained in:
6
.github/workflows/integration-tests.yml
vendored
6
.github/workflows/integration-tests.yml
vendored
@@ -18,12 +18,16 @@ jobs:
|
|||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v2
|
uses: actions/checkout@v2
|
||||||
|
|
||||||
|
- name: Clear cache
|
||||||
|
run: |
|
||||||
|
rm -r /tmp/triton/
|
||||||
|
continue-on-error: true
|
||||||
|
|
||||||
- name: Install Triton
|
- name: Install Triton
|
||||||
run: |
|
run: |
|
||||||
alias python='python3'
|
alias python='python3'
|
||||||
cd python
|
cd python
|
||||||
pip3 install -e .
|
pip3 install -e .
|
||||||
rm -r /tmp/triton/
|
|
||||||
|
|
||||||
- name: Unit tests
|
- name: Unit tests
|
||||||
run: |
|
run: |
|
||||||
|
@@ -537,6 +537,7 @@ void layouts::run(ir::module &mod) {
|
|||||||
tmp_[atom] = id;
|
tmp_[atom] = id;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -2197,7 +2197,8 @@ void generator::visit_async_wait_inst(ir::async_wait_inst* i) {
|
|||||||
|
|
||||||
void generator::visit_make_range(ir::make_range* x) {
|
void generator::visit_make_range(ir::make_range* x) {
|
||||||
for(indices_t idx: idxs_.at(x)){
|
for(indices_t idx: idxs_.at(x)){
|
||||||
vals_[x][idx] = idx[0];
|
Value* start = ConstantInt::get(idx[0]->getType(), x->get_first()->get_value());
|
||||||
|
vals_[x][idx] = add(start, idx[0]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -875,7 +875,7 @@ make_range *make_range::create(constant_int *first, constant_int *last) {
|
|||||||
assert(first->get_type()->is_integer_ty());
|
assert(first->get_type()->is_integer_ty());
|
||||||
assert(first->get_type() == last->get_type());
|
assert(first->get_type() == last->get_type());
|
||||||
assert(((constant_int*)first)->get_value() == 0);
|
assert(((constant_int*)first)->get_value() == 0);
|
||||||
type *ty = block_type::get(first->get_type(), {(unsigned)last->get_value()});
|
type *ty = block_type::get(first->get_type(), {(unsigned)last->get_value() - (unsigned)first->get_value()});
|
||||||
return new make_range(ty, first, last);
|
return new make_range(ty, first, last);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -476,6 +476,7 @@ void init_triton_ir(py::module &&m) {
|
|||||||
// constants
|
// constants
|
||||||
.def("get_int1", &ir::builder::get_int1, ret::reference)
|
.def("get_int1", &ir::builder::get_int1, ret::reference)
|
||||||
.def("get_int32", &ir::builder::get_int32, ret::reference)
|
.def("get_int32", &ir::builder::get_int32, ret::reference)
|
||||||
|
.def("get_int64", &ir::builder::get_int64, ret::reference)
|
||||||
.def("get_float16", &ir::builder::get_float16, ret::reference)
|
.def("get_float16", &ir::builder::get_float16, ret::reference)
|
||||||
.def("get_float32", &ir::builder::get_float32, ret::reference)
|
.def("get_float32", &ir::builder::get_float32, ret::reference)
|
||||||
.def("get_range", &ir::builder::get_range, ret::reference);
|
.def("get_range", &ir::builder::get_range, ret::reference);
|
||||||
|
@@ -515,6 +515,22 @@ def test_dot(epilogue, device='cuda'):
|
|||||||
assert 'ld.global.v4' in ptx
|
assert 'ld.global.v4' in ptx
|
||||||
assert 'st.global.v4' in ptx
|
assert 'st.global.v4' in ptx
|
||||||
|
|
||||||
|
# ---------------
|
||||||
|
# test arange
|
||||||
|
# ---------------
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("start", [0, 1, 7, 16])
|
||||||
|
def test_arange(start, device='cuda'):
|
||||||
|
BLOCK = 128
|
||||||
|
z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)
|
||||||
|
@triton.jit
|
||||||
|
def _kernel(z, **meta):
|
||||||
|
off = tl.arange(0, meta['BLOCK'])
|
||||||
|
val = tl.arange(meta['START'], meta['END'])
|
||||||
|
tl.store(z + off, val)
|
||||||
|
_kernel[(1,)](z_tri, START=start, END=start+BLOCK, BLOCK=BLOCK)
|
||||||
|
z_ref = torch.arange(start, BLOCK+start, dtype=torch.int32, device=device)
|
||||||
|
triton.testing.assert_almost_equal(z_tri, z_ref)
|
||||||
|
|
||||||
# ---------------
|
# ---------------
|
||||||
# test load
|
# test load
|
||||||
|
@@ -112,7 +112,7 @@ BLOCK = 1024
|
|||||||
# test generation of random uint32
|
# test generation of random uint32
|
||||||
@pytest.mark.parametrize('size, seed',
|
@pytest.mark.parametrize('size, seed',
|
||||||
[(size, seed) for size in ['10', '4,53', '10000']\
|
[(size, seed) for size in ['10', '4,53', '10000']\
|
||||||
for seed in [0, 42, 124, 54]]
|
for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
|
||||||
)
|
)
|
||||||
def test_randint(size, seed, device='cuda'):
|
def test_randint(size, seed, device='cuda'):
|
||||||
size = list(map(int, size.split(',')))
|
size = list(map(int, size.split(',')))
|
||||||
|
@@ -103,7 +103,8 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
arg_values = []
|
arg_values = []
|
||||||
for i, arg_name in enumerate(arg_names):
|
for i, arg_name in enumerate(arg_names):
|
||||||
if i in self.constants:
|
if i in self.constants:
|
||||||
arg_values.append(self.constants[i])
|
cst = triton.language.core._to_ir(self.constants[i], self.builder)
|
||||||
|
arg_values.append(cst)
|
||||||
else:
|
else:
|
||||||
if i in self.attributes:
|
if i in self.attributes:
|
||||||
is_ptr = fn.args[i].type.is_ptr()
|
is_ptr = fn.args[i].type.is_ptr()
|
||||||
@@ -463,9 +464,6 @@ class Kernel:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _type_name(obj):
|
def _type_name(obj):
|
||||||
type_names = {
|
type_names = {
|
||||||
int: 'I',
|
|
||||||
float: 'f',
|
|
||||||
bool: 'B',
|
|
||||||
triton.language.float8: 'f8',
|
triton.language.float8: 'f8',
|
||||||
torch.bfloat16: 'bf16',
|
torch.bfloat16: 'bf16',
|
||||||
torch.float16: 'f16',
|
torch.float16: 'f16',
|
||||||
@@ -477,12 +475,25 @@ class Kernel:
|
|||||||
torch.int32: 'i32',
|
torch.int32: 'i32',
|
||||||
torch.int64: 'i64',
|
torch.int64: 'i64',
|
||||||
}
|
}
|
||||||
return type_names[obj]
|
if hasattr(obj, 'data_ptr'):
|
||||||
|
return type_names[obj.dtype]
|
||||||
|
if isinstance(obj, int):
|
||||||
|
if abs(obj) <= 0xffffffff:
|
||||||
|
return 'I'
|
||||||
|
return 'L'
|
||||||
|
if isinstance(obj, float):
|
||||||
|
return 'f'
|
||||||
|
if isinstance(obj, bool):
|
||||||
|
return 'B'
|
||||||
|
assert False
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _to_triton_ir(context, obj):
|
def _to_triton_ir(context, obj):
|
||||||
type_map = {
|
type_map = {
|
||||||
'I': _triton.ir.type.get_int32,
|
'I': _triton.ir.type.get_int32,
|
||||||
|
'L': _triton.ir.type.get_int64,
|
||||||
'f': _triton.ir.type.get_fp32,
|
'f': _triton.ir.type.get_fp32,
|
||||||
'B': _triton.ir.type.get_int1,
|
'B': _triton.ir.type.get_int1,
|
||||||
'f8': _triton.ir.type.get_fp8,
|
'f8': _triton.ir.type.get_fp8,
|
||||||
@@ -498,11 +509,11 @@ class Kernel:
|
|||||||
}
|
}
|
||||||
# convert torch.Tensor to Triton IR pointers
|
# convert torch.Tensor to Triton IR pointers
|
||||||
if hasattr(obj, 'data_ptr'):
|
if hasattr(obj, 'data_ptr'):
|
||||||
name = Kernel._type_name(obj.dtype)
|
name = Kernel._type_name(obj)
|
||||||
elt_ty = type_map[name](context)
|
elt_ty = type_map[name](context)
|
||||||
return _triton.ir.type.make_ptr(elt_ty, 1)
|
return _triton.ir.type.make_ptr(elt_ty, 1)
|
||||||
# default path returns triton.ir.type directly
|
# default path returns triton.ir.type directly
|
||||||
name = Kernel._type_name(obj.__class__)
|
name = Kernel._type_name(obj)
|
||||||
return type_map[name](context)
|
return type_map[name](context)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -511,7 +522,7 @@ class Kernel:
|
|||||||
types_key = [None] * len(wargs)
|
types_key = [None] * len(wargs)
|
||||||
for i, arg in enumerate(wargs):
|
for i, arg in enumerate(wargs):
|
||||||
prefix = 'P' if i in tensor_idxs else ''
|
prefix = 'P' if i in tensor_idxs else ''
|
||||||
suffix = Kernel._type_name(arg.dtype) if i in tensor_idxs else Kernel._type_name(arg.__class__)
|
suffix = Kernel._type_name(arg) if i in tensor_idxs else Kernel._type_name(arg)
|
||||||
types_key[i] = prefix + suffix
|
types_key[i] = prefix + suffix
|
||||||
return tuple(types_key)
|
return tuple(types_key)
|
||||||
|
|
||||||
@@ -646,7 +657,7 @@ class Kernel:
|
|||||||
|
|
||||||
drv_cache[key] = LoadedBinary(device_idx, binary)
|
drv_cache[key] = LoadedBinary(device_idx, binary)
|
||||||
# pack arguments
|
# pack arguments
|
||||||
fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
|
fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg) for i, arg in enumerate(wargs)])
|
||||||
params = struct.pack(fmt, *args)
|
params = struct.pack(fmt, *args)
|
||||||
# enqueue cached function into stream
|
# enqueue cached function into stream
|
||||||
callable = drv_cache[key]
|
callable = drv_cache[key]
|
||||||
|
@@ -9,7 +9,9 @@ def _to_ir(x, builder):
|
|||||||
if isinstance(x, bool):
|
if isinstance(x, bool):
|
||||||
return builder.get_int1(x)
|
return builder.get_int1(x)
|
||||||
elif isinstance(x, int):
|
elif isinstance(x, int):
|
||||||
|
if x.__abs__() <= 2**31:
|
||||||
return builder.get_int32(x)
|
return builder.get_int32(x)
|
||||||
|
return builder.get_int64(x)
|
||||||
elif isinstance(x, float):
|
elif isinstance(x, float):
|
||||||
return builder.get_float32(x)
|
return builder.get_float32(x)
|
||||||
if isinstance(x, block):
|
if isinstance(x, block):
|
||||||
@@ -636,6 +638,10 @@ def max_contiguous(input, value, _builder=None):
|
|||||||
# Standard library
|
# Standard library
|
||||||
# -----------------------
|
# -----------------------
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def abs(x):
|
||||||
|
return where(x >= 0, x, -x)
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def cdiv(x, div):
|
def cdiv(x, div):
|
||||||
"""
|
"""
|
||||||
|
@@ -128,7 +128,10 @@ def randint4x(seed, offset):
|
|||||||
:param offsets: The offsets to generate random numbers for.
|
:param offsets: The offsets to generate random numbers for.
|
||||||
"""
|
"""
|
||||||
z = 0
|
z = 0
|
||||||
return philox_f(offset, z, z, z, seed, z)
|
seed = hacky_to_uint64(seed) # uint will solve this
|
||||||
|
seed_hi = ((seed >> 32) & 0xffffffff).to(tl.int32)
|
||||||
|
seed_lo = (seed & 0xffffffff).to(tl.int32)
|
||||||
|
return philox_f(offset, z, z, z, seed_lo, seed_hi)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
|
Reference in New Issue
Block a user