[python] basic support for pytorch seems to be working

Philippe Tillet
2019-09-05 01:32:21 -04:00
parent ed0f706005
commit 65133cdf33
2 changed files with 26 additions and 15 deletions


@@ -34,6 +34,5 @@ def run_torch():
   b = th.randn(K, N).cuda()
   th_c = th.matmul(a, b)
   tr_c = triton.ops.dot(a, b)
-  print(c)
 
 run_torch()
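
Note: the deleted print(c) referenced a name, c, that is never defined in this script, so leaving it in would have raised a NameError at runtime. As a hedged sketch (not part of the commit), the two results could be compared numerically instead; the matrix sizes and tolerance below are assumptions:

import torch as th
import triton

def run_torch(M=512, N=512, K=512):
  a = th.randn(M, K).cuda()
  b = th.randn(K, N).cuda()
  th_c = th.matmul(a, b)        # reference result from PyTorch
  tr_c = triton.ops.dot(a, b)   # result from the Triton kernel
  # hypothetical check: largest elementwise discrepancy
  diff = (th_c - tr_c).abs().max().item()
  print('max abs diff:', diff)
  assert diff < 1e-3            # tolerance is an assumption

run_torch()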


@@ -4,45 +4,49 @@ class OpContext(object):
   def save_for_backward(self, *tensors):
     self.to_save = tensors
 
+  def mark_dirty(self, *args):
+    self.dirty_tensors = args
+
   @property
   def saved_tensors(self):
     return self.to_save
 
 class function_meta(type):
 
   def __init__(cls, name, bases, attrs):
     cls.contexts = dict()
     cls.registered = False
+    cls.framework = None
     return super(function_meta, cls).__init__(name, bases, attrs)
 
 class function(metaclass = function_meta):
 
-  def __init__(self, framework = None):
-    self.framework = _find_framework(framework)
-    pass
-
   @staticmethod
   def forward(ctx, *args, **kwargs):
     raise NotImplementedError
 
   @staticmethod
   def backward(ctx, grad_output):
     raise NotImplementedError
 
   @classmethod
-  def apply(cls, *args, **kwargs):
-    # call forward
+  def apply_torch(cls, *args, **kwargs):
+    fw._import_torch()
+    class TorchFunction(fw.torch.autograd.Function):
+      @staticmethod
+      def forward(ctx, *targs, **tkwargs):
+        return cls.forward(ctx, *targs, **tkwargs)
+      @staticmethod
+      def backward(ctx, grad_output):
+        return cls.backward(ctx, grad_output)
+    return TorchFunction.apply(*args, **kwargs)
+
+  @classmethod
+  def apply_tensorflow(cls, *args, **kwargs):
+    fw._import_tensorflow()
     ctx = OpContext()
     result = cls.forward(ctx, *args, **kwargs)
     id = result.op.get_attr('id')
     cls.contexts[id] = ctx
     # register backward
-    fw._import_tensorflow()
     name = result.op.op_def.name
     if not cls.registered:
       @fw.tensorflow.RegisterGradient(name)
@@ -51,4 +55,12 @@ class function(metaclass = function_meta):
         return cls.backward(cls.contexts[id], dy)
       cls.registered = True
     # return result tensor
     return result
+
+  @classmethod
+  def apply(cls, *args, **kwargs):
+    cls.framework = fw._find_framework(cls.framework)
+    if cls.framework == fw.tensorflow_id:
+      return cls.apply_tensorflow(*args, **kwargs)
+    else:
+      return cls.apply_torch(*args, **kwargs)
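
With this change, function mirrors the torch.autograd.Function interface: subclasses implement forward and backward as staticmethods and invoke the op through the classmethod apply, which dispatches to apply_torch or apply_tensorflow depending on the detected framework. A minimal sketch of a subclass written against this API; the _scale op is a made-up placeholder, and it assumes the class is importable as triton.function:

import torch as th
import triton

class _scale(triton.function):
  # hypothetical op: scales its input by a constant

  @staticmethod
  def forward(ctx, x, alpha):
    ctx.alpha = alpha           # stash what backward will need
    return x * alpha

  @staticmethod
  def backward(ctx, grad_output):
    # gradient w.r.t. x; None for the non-tensor argument alpha
    return grad_output * ctx.alpha, None

# with torch tensors, apply() wraps the subclass in a
# torch.autograd.Function and calls its .apply()
x = th.randn(4, device='cuda', requires_grad=True)
y = _scale.apply(x, 2.0)
y.sum().backward()              # routes through _scale.backward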