diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index a253e2c4c..23d460f29 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -654,9 +654,13 @@ class CompilationError(Exception): self.message = f'at {node.lineno}:{node.col_offset}:\n' self.message += '\n'.join(src.split('\n')[:node.lineno]) self.message += '\n' + ' ' * node.col_offset + '^' + self.src = src + self.node = node super().__init__(self.message) + + def __reduce__(self): # this is necessary to make CompilationError picklable - self.args = (src, node) + return (type(self), (self.src, self.node)) class OutOfResources(Exception): @@ -664,8 +668,14 @@ class OutOfResources(Exception): self.message = f'out of resource: {name}, '\ f'Required: {required}, '\ f'Hardware limit: {limit}' + self.required = required + self.limit = limit + self.name = name super().__init__(self.message) - self.args = (required, limit, name) + + def __reduce__(self): + # this is necessary to make CompilationError picklable + return (type(self), (self.required, self.limit, self.name)) class Kernel: