99 lines
2.7 KiB
Python
99 lines
2.7 KiB
Python
from .running_stat import RunningStat
|
|
from collections import deque
|
|
import numpy as np
|
|
|
|
class Filter(object):
    """Base class for callable input filters.

    A subclass implements ``__call__(x, update=True)`` to transform ``x``;
    ``reset`` is an optional hook whose default does nothing.
    """

    def __call__(self, x, update=True):
        """Transform ``x``. Must be overridden; the base always raises."""
        raise NotImplementedError

    def reset(self):
        """Clear any internal state (no-op in the base class)."""
        return None
|
|
|
|
class IdentityFilter(Filter):
    """Pass-through filter: returns its input unchanged."""

    def __call__(self, x, update=True):
        # ``update`` is accepted for interface compatibility and ignored.
        return x

    def output_shape(self, input_space):
        """Return ``input_space.shape`` unchanged.

        Needed so this filter can be used inside
        ``CompositionFilter.output_shape``, which calls ``output_shape`` on
        every child filter.
        """
        return input_space.shape
|
|
|
|
class CompositionFilter(Filter):
    """Apply a sequence of filters left to right.

    Each filter receives the previous filter's output.
    """

    class _ShapeOnly(object):
        """Minimal space stand-in exposing only ``shape``.

        Used to feed an intermediate shape to the next child's
        ``output_shape``, which reads ``input_space.shape``.
        """

        def __init__(self, shape):
            self.shape = shape

    def __init__(self, fs):
        # fs: sequence of filters, applied in iteration order.
        self.fs = fs

    def reset(self):
        """Reset every child filter so stateful children start fresh."""
        for f in self.fs:
            f.reset()

    def __call__(self, x, update=True):
        """Run ``x`` through every child, forwarding the ``update`` flag."""
        for f in self.fs:
            # Forward ``update`` so children (e.g. running statistics) are
            # not updated when the caller asked for update=False.
            x = f(x, update=update)
        return x

    def output_shape(self, input_space):
        """Return the shape produced by chaining all children.

        The first child receives the original ``input_space`` (so children
        that read other attributes of the space still work). Later children
        receive a stand-in whose ``shape`` is the previous child's output
        shape. With no children, ``input_space.shape`` is returned.
        """
        space = input_space
        out = input_space.shape
        for f in self.fs:
            out = f.output_shape(space)
            space = CompositionFilter._ShapeOnly(out)
        return out
|
|
|
|
class ZFilter(Filter):
    """Standardize inputs using running statistics.

    Computes ``y = (x - mean) / std`` from a ``RunningStat`` estimator, with
    each step individually switchable, then optionally clips to
    ``[-clip, clip]``.
    """

    def __init__(self, shape, demean=True, destd=True, clip=10.0):
        self.demean, self.destd, self.clip = demean, destd, clip
        # Running estimator of mean / std over inputs of ``shape``.
        self.rs = RunningStat(shape)

    def __call__(self, x, update=True):
        """Optionally record ``x`` in the statistics, then normalize it."""
        if update:
            self.rs.push(x)
        stats = self.rs
        if self.demean:
            x = x - stats.mean
        if self.destd:
            # Small epsilon guards against division by zero.
            x = x / (stats.std + 1e-8)
        if self.clip:
            # A falsy clip value (0 / None) disables clipping.
            bound = self.clip
            x = np.clip(x, -bound, bound)
        return x

    def output_shape(self, input_space):
        """Normalization preserves the input shape."""
        return input_space.shape
|
|
|
|
class AddClock(Filter):
    """Append a time feature ``count / 100.0`` to the (flattened) input.

    ``count`` is the number of calls made since construction or the last
    ``reset``; the first call after a reset appends 0.0.
    """

    def __init__(self):
        self.count = 0

    def reset(self):
        """Restart the clock at zero."""
        self.count = 0

    def __call__(self, x, update=True):
        # np.append flattens x and returns a new 1-D array.
        out = np.append(x, self.count / 100.0)
        # Advance the clock. Without this increment the appended feature
        # was always 0.0 and reset() had no observable effect.
        self.count += 1
        return out

    def output_shape(self, input_space):
        """One extra element; assumes a 1-D input space."""
        return (input_space.shape[0] + 1,)
|
|
|
|
class FlattenFilter(Filter):
    """Flatten array inputs to one dimension via ``ndarray.ravel``."""

    def __call__(self, x, update=True):
        flat = x.ravel()
        return flat

    def output_shape(self, input_space):
        """A 1-tuple holding the total element count of the input shape."""
        total = int(np.prod(input_space.shape))
        return (total,)
|
|
|
|
class Ind2OneHotFilter(Filter):
    """Encode an index as a float one-hot vector of length ``n``."""

    def __init__(self, n):
        # Length of the output vector.
        self.n = n

    def __call__(self, x, update=True):
        encoded = np.zeros(self.n)
        encoded[x] = 1
        return encoded

    def output_shape(self, input_space):
        """Output length comes from the input space's ``n`` attribute."""
        size = input_space.n
        return (size,)
|
|
|
|
class DivFilter(Filter):
    """Divide every input by a fixed divisor."""

    def __init__(self, divisor):
        self.divisor = divisor

    def __call__(self, x, update=True):
        scale = self.divisor
        return x / scale

    def output_shape(self, input_space):
        """Division is elementwise, so the shape is unchanged."""
        shape = input_space.shape
        return shape
|
|
|
|
class StackFilter(Filter):
    """Concatenate the last ``length`` inputs along the final axis."""

    def __init__(self, length):
        # Bounded window: the oldest input is dropped once ``length`` is hit.
        self.stack = deque(maxlen=length)

    def reset(self):
        """Empty the window so the next call re-pads from scratch."""
        self.stack.clear()

    def __call__(self, x, update=True):
        history = self.stack
        history.append(x)
        # After a reset (or on the first call) the window is not full yet:
        # pad it with copies of the current input.
        history.extend([x] * (history.maxlen - len(history)))
        return np.concatenate(history, axis=-1)

    def output_shape(self, input_space):
        """Last dimension grows by a factor of the window length."""
        shape = input_space.shape
        return shape[:-1] + (shape[-1] * self.stack.maxlen,)
|