import tensorflow as tf
import numpy as np
import re
from baselines.acktr.kfac_utils import gmatmul, detectMinVal, factorReshape
from functools import reduce

KFAC_OPS = ['MatMul', 'Conv2D', 'BiasAdd']
KFAC_DEBUG = False
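
# -----------------------------------------------------------------------------
# KfacOptimizer implements the K-FAC (Kronecker-Factored Approximate Curvature)
# optimizer used by ACKTR. Per layer it keeps running estimates of the input
# activation covariance and the pre-activation gradient covariance, periodically
# eigendecomposes them, and uses the resulting Kronecker-factored Fisher
# approximation to precondition the gradients before a momentum step.
#
# Rough usage sketch (illustrative only; `loss` and `loss_sampled` are scalar
# tensors built by the caller, `loss_sampled` typically being the loss under
# labels sampled from the model's own predictive distribution):
#
#   optim = KfacOptimizer(learning_rate=0.03, clip_kl=0.001, momentum=0.9)
#   update_op, q_runner = optim.minimize(loss, loss_sampled)
#   # q_runner is None unless async_=True; see minimize() below.
# -----------------------------------------------------------------------------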


class KfacOptimizer():

    def __init__(self, learning_rate=0.01, momentum=0.9, clip_kl=0.01,
                 kfac_update=2, stats_accum_iter=60, full_stats_init=False,
                 cold_iter=100, cold_lr=None, async_=False, async_stats=False,
                 epsilon=1e-2, stats_decay=0.95, blockdiag_bias=False,
                 channel_fac=False, factored_damping=False, approxT2=False,
                 use_float64=False, weight_decay_dict={}, max_grad_norm=0.5):
        # `async` became a reserved keyword in Python 3.7, so the
        # asynchronous-update flag is exposed as `async_` here.
        self.max_grad_norm = max_grad_norm
        self._lr = learning_rate
        self._momentum = momentum
        self._clip_kl = clip_kl
        self._channel_fac = channel_fac
        self._kfac_update = kfac_update
        self._async = async_
        self._async_stats = async_stats
        self._epsilon = epsilon
        self._stats_decay = stats_decay
        self._blockdiag_bias = blockdiag_bias
        self._approxT2 = approxT2
        self._use_float64 = use_float64
        self._factored_damping = factored_damping
        self._cold_iter = cold_iter
        if cold_lr is None:
            # good heuristic
            self._cold_lr = self._lr  # * 3.
        else:
            self._cold_lr = cold_lr
        self._stats_accum_iter = stats_accum_iter
        self._weight_decay_dict = weight_decay_dict
        self._diag_init_coeff = 0.
        self._full_stats_init = full_stats_init
        if not self._full_stats_init:
            self._stats_accum_iter = self._cold_iter

        self.sgd_step = tf.Variable(0, name='KFAC/sgd_step', trainable=False)
        self.global_step = tf.Variable(
            0, name='KFAC/global_step', trainable=False)
        self.cold_step = tf.Variable(0, name='KFAC/cold_step', trainable=False)
        self.factor_step = tf.Variable(
            0, name='KFAC/factor_step', trainable=False)
        self.stats_step = tf.Variable(
            0, name='KFAC/stats_step', trainable=False)
        self.vFv = tf.Variable(0., name='KFAC/vFv', trainable=False)

        self.factors = {}
        self.param_vars = []
        self.stats = {}
        self.stats_eigen = {}
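
    # getFactors walks backwards from each sampled gradient to the op that
    # produced it (MatMul, Conv2D or BiasAdd), recording the forward input
    # tensor ("fprop factor", the layer activations) and the backward gradient
    # tensor ("bprop factor") for every trainable variable. It also pairs each
    # bias with its associated weight matrix so the two can share one factor
    # via homogeneous coordinates, and de-duplicates factors shared by several
    # parameters.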

    def getFactors(self, g, varlist):
        graph = tf.get_default_graph()
        factorTensors = {}
        fpropTensors = []
        bpropTensors = []
        opTypes = []
        fops = []

        def searchFactors(gradient, graph):
            # hard-coded search strategy
            bpropOp = gradient.op
            bpropOp_name = bpropOp.name

            bTensors = []
            fTensors = []

            # combining additive gradients; assume they are the same op type
            # and independent
            if 'AddN' in bpropOp_name:
                factors = []
                for g in gradient.op.inputs:
                    factors.append(searchFactors(g, graph))
                op_names = [item['opName'] for item in factors]
                # TO-DO: need to check all the attributes of the ops as well
                print(gradient.name)
                print(op_names)
                print(len(np.unique(op_names)))
                assert len(np.unique(op_names)) == 1, gradient.name + \
                    ' is shared among different computation OPs'

                bTensors = reduce(lambda x, y: x + y,
                                  [item['bpropFactors'] for item in factors])
                if len(factors[0]['fpropFactors']) > 0:
                    fTensors = reduce(
                        lambda x, y: x + y, [item['fpropFactors'] for item in factors])
                fpropOp_name = op_names[0]
                fpropOp = factors[0]['op']
            else:
                fpropOp_name = re.search(
                    'gradientsSampled(_[0-9]+|)/(.+?)_grad', bpropOp_name).group(2)
                fpropOp = graph.get_operation_by_name(fpropOp_name)
                if fpropOp.op_def.name in KFAC_OPS:
                    # Known OPs
                    ###
                    bTensor = [
                        i for i in bpropOp.inputs if 'gradientsSampled' in i.name][-1]
                    bTensorShape = fpropOp.outputs[0].get_shape()
                    if bTensor.get_shape()[0].value is None:
                        bTensor.set_shape(bTensorShape)
                    bTensors.append(bTensor)
                    ###
                    if fpropOp.op_def.name == 'BiasAdd':
                        fTensors = []
                    else:
                        # `param` comes from the enclosing loop over
                        # (gradient, variable) pairs below
                        fTensors.append(
                            [i for i in fpropOp.inputs if param.op.name not in i.name][0])
                    fpropOp_name = fpropOp.op_def.name
                else:
                    # unknown OPs, block approximation used
                    bInputsList = [i for i in bpropOp.inputs[0].op.inputs
                                   if 'gradientsSampled' in i.name if 'Shape' not in i.name]
                    if len(bInputsList) > 0:
                        bTensor = bInputsList[0]
                        bTensorShape = fpropOp.outputs[0].get_shape()
                        if len(bTensor.get_shape()) > 0 and bTensor.get_shape()[0].value is None:
                            bTensor.set_shape(bTensorShape)
                        bTensors.append(bTensor)
                    # list.append returns None, so record the name explicitly
                    # instead of assigning the append result
                    fpropOp_name = 'UNK-' + fpropOp.op_def.name
                    opTypes.append(fpropOp_name)

            return {'opName': fpropOp_name, 'op': fpropOp, 'fpropFactors': fTensors, 'bpropFactors': bTensors}

        for t, param in zip(g, varlist):
            if KFAC_DEBUG:
                print(('get factor for ' + param.name))
            factors = searchFactors(t, graph)
            factorTensors[param] = factors

        ########
        # check associated weights and bias for homogeneous coordinate representation
        # and check redundant factors
        # TO-DO: there may be a bug in detecting the associated bias and weights
        # for forking layers, e.g. in inception models.
        for param in varlist:
            factorTensors[param]['assnWeights'] = None
            factorTensors[param]['assnBias'] = None
        for param in varlist:
            if factorTensors[param]['opName'] == 'BiasAdd':
                factorTensors[param]['assnWeights'] = None
                for item in varlist:
                    if len(factorTensors[item]['bpropFactors']) > 0:
                        if (set(factorTensors[item]['bpropFactors']) == set(factorTensors[param]['bpropFactors'])) and (len(factorTensors[item]['fpropFactors']) > 0):
                            factorTensors[param]['assnWeights'] = item
                            factorTensors[item]['assnBias'] = param
                            factorTensors[param]['bpropFactors'] = factorTensors[item]['bpropFactors']
        ########

        ########
        # concatenate the additive gradients along the batch dimension, i.e.
        # assuming independence structure
        for key in ['fpropFactors', 'bpropFactors']:
            for i, param in enumerate(varlist):
                if len(factorTensors[param][key]) > 0:
                    if (key + '_concat') not in factorTensors[param]:
                        name_scope = factorTensors[param][key][0].name.split(':')[0]
                        with tf.name_scope(name_scope):
                            factorTensors[param][key + '_concat'] = tf.concat(
                                factorTensors[param][key], 0)
                else:
                    factorTensors[param][key + '_concat'] = None
                for j, param2 in enumerate(varlist[(i + 1):]):
                    if (len(factorTensors[param][key]) > 0) and (set(factorTensors[param2][key]) == set(factorTensors[param][key])):
                        factorTensors[param2][key] = factorTensors[param][key]
                        factorTensors[param2][key + '_concat'] = factorTensors[param][key + '_concat']
        ########

        if KFAC_DEBUG:
            for items in zip(varlist, fpropTensors, bpropTensors, opTypes):
                print((items[0].name, factorTensors[items[0]]))
        self.factors = factorTensors
        return factorTensors
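
    # getStats allocates one non-trainable covariance variable per Kronecker
    # factor found above: an activation covariance for the forward factor and
    # a gradient covariance for the backward factor, each initialized to
    # _diag_init_coeff * I (zeros by default). Factors shared between layers
    # reuse the same variable through tmpStatsCache.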

    def getStats(self, factors, varlist):
        if len(self.stats) == 0:
            # initialize stats variables on CPU because eigen decomp is
            # computed on CPU
            with tf.device('/cpu'):
                tmpStatsCache = {}

                # search for tensor factors and
                # use block diag approx for the bias units
                for var in varlist:
                    fpropFactor = factors[var]['fpropFactors_concat']
                    bpropFactor = factors[var]['bpropFactors_concat']
                    opType = factors[var]['opName']
                    if opType == 'Conv2D':
                        Kh = var.get_shape()[0]
                        Kw = var.get_shape()[1]
                        C = fpropFactor.get_shape()[-1]

                        Oh = bpropFactor.get_shape()[1]
                        Ow = bpropFactor.get_shape()[2]
                        if Oh == 1 and Ow == 1 and self._channel_fac:
                            # factorization along the channels does not support
                            # homogeneous coordinates
                            var_assnBias = factors[var]['assnBias']
                            if var_assnBias:
                                factors[var]['assnBias'] = None
                                factors[var_assnBias]['assnWeights'] = None
                ##

                for var in varlist:
                    fpropFactor = factors[var]['fpropFactors_concat']
                    bpropFactor = factors[var]['bpropFactors_concat']
                    opType = factors[var]['opName']
                    self.stats[var] = {'opName': opType,
                                       'fprop_concat_stats': [],
                                       'bprop_concat_stats': [],
                                       'assnWeights': factors[var]['assnWeights'],
                                       'assnBias': factors[var]['assnBias'],
                                       }
                    if fpropFactor is not None:
                        if fpropFactor not in tmpStatsCache:
                            if opType == 'Conv2D':
                                Kh = var.get_shape()[0]
                                Kw = var.get_shape()[1]
                                C = fpropFactor.get_shape()[-1]

                                Oh = bpropFactor.get_shape()[1]
                                Ow = bpropFactor.get_shape()[2]
                                if Oh == 1 and Ow == 1 and self._channel_fac:
                                    # factorization along the channels:
                                    # assume independence between input channels
                                    # and spatial locations, giving a
                                    # Kh*Kw x Kh*Kw and a C x C covariance matrix.
                                    # this factorization does not support
                                    # homogeneous coordinates; assnBias is
                                    # always None here
                                    fpropFactor2_size = Kh * Kw
                                    slot_fpropFactor_stats2 = tf.Variable(
                                        tf.diag(tf.ones([fpropFactor2_size])) * self._diag_init_coeff,
                                        name='KFAC_STATS/' + fpropFactor.op.name, trainable=False)
                                    self.stats[var]['fprop_concat_stats'].append(
                                        slot_fpropFactor_stats2)

                                    fpropFactor_size = C
                                else:
                                    # Kh*Kw*C x Kh*Kw*C covariance matrix
                                    # assume BHWC
                                    fpropFactor_size = Kh * Kw * C
                            else:
                                # D x D covariance matrix
                                fpropFactor_size = fpropFactor.get_shape()[-1]

                            # use homogeneous coordinates
                            if not self._blockdiag_bias and self.stats[var]['assnBias']:
                                fpropFactor_size += 1

                            slot_fpropFactor_stats = tf.Variable(
                                tf.diag(tf.ones([fpropFactor_size])) * self._diag_init_coeff,
                                name='KFAC_STATS/' + fpropFactor.op.name, trainable=False)
                            self.stats[var]['fprop_concat_stats'].append(
                                slot_fpropFactor_stats)
                            if opType != 'Conv2D':
                                tmpStatsCache[fpropFactor] = self.stats[var]['fprop_concat_stats']
                        else:
                            self.stats[var]['fprop_concat_stats'] = tmpStatsCache[fpropFactor]

                    if bpropFactor is not None:
                        # no need to collect backward stats for bias vectors if
                        # using homogeneous coordinates
                        if not ((not self._blockdiag_bias) and self.stats[var]['assnWeights']):
                            if bpropFactor not in tmpStatsCache:
                                slot_bpropFactor_stats = tf.Variable(
                                    tf.diag(tf.ones([bpropFactor.get_shape()[-1]])) * self._diag_init_coeff,
                                    name='KFAC_STATS/' + bpropFactor.op.name, trainable=False)
                                self.stats[var]['bprop_concat_stats'].append(
                                    slot_bpropFactor_stats)
                                tmpStatsCache[bpropFactor] = self.stats[var]['bprop_concat_stats']
                            else:
                                self.stats[var]['bprop_concat_stats'] = tmpStatsCache[bpropFactor]

        return self.stats
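
    # Convenience wrapper: build the covariance estimates from the sampled
    # loss and immediately return the op that folds them into the running
    # averages.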

    def compute_and_apply_stats(self, loss_sampled, var_list=None):
        varlist = var_list
        if varlist is None:
            varlist = tf.trainable_variables()

        stats = self.compute_stats(loss_sampled, var_list=varlist)
        return self.apply_stats(stats)
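
    # compute_stats builds the per-batch covariance updates. For a layer with
    # (possibly bias-augmented) input activations a (B x Din) and sampled-loss
    # pre-activation gradients g, the two Kronecker factors are estimated as
    #
    #   A = a^T a / B
    #   G = (B*g)^T (B*g) / rows(g)
    #
    # where rows(g) is B for dense layers and B*Oh*Ow for convolutions after
    # the spatial reshape; the B rescaling assumes the sampled loss was
    # averaged over the batch. Conv2D layers are handled by extracting image
    # patches (or a rank-1 SVD approximation when channel factorization is
    # enabled) so the same matrix products apply.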

    def compute_stats(self, loss_sampled, var_list=None):
        varlist = var_list
        if varlist is None:
            varlist = tf.trainable_variables()

        gs = tf.gradients(loss_sampled, varlist, name='gradientsSampled')
        self.gs = gs
        factors = self.getFactors(gs, varlist)
        stats = self.getStats(factors, varlist)

        updateOps = []
        statsUpdates = {}
        statsUpdates_cache = {}
        for var in varlist:
            opType = factors[var]['opName']
            fops = factors[var]['op']
            fpropFactor = factors[var]['fpropFactors_concat']
            fpropStats_vars = stats[var]['fprop_concat_stats']
            bpropFactor = factors[var]['bpropFactors_concat']
            bpropStats_vars = stats[var]['bprop_concat_stats']
            SVD_factors = {}
            for stats_var in fpropStats_vars:
                stats_var_dim = int(stats_var.get_shape()[0])
                if stats_var not in statsUpdates_cache:
                    old_fpropFactor = fpropFactor
                    B = (tf.shape(fpropFactor)[0])  # batch size
                    if opType == 'Conv2D':
                        strides = fops.get_attr("strides")
                        padding = fops.get_attr("padding")
                        convkernel_size = var.get_shape()[0:3]

                        KH = int(convkernel_size[0])
                        KW = int(convkernel_size[1])
                        C = int(convkernel_size[2])
                        flatten_size = int(KH * KW * C)

                        Oh = int(bpropFactor.get_shape()[1])
                        Ow = int(bpropFactor.get_shape()[2])

                        if Oh == 1 and Ow == 1 and self._channel_fac:
                            # factorization along the channels
                            # assume independence among input channels
                            # factor = B x 1 x 1 x (KH x KW x C)
                            # patches = B x Oh x Ow x (KH x KW x C)
                            if len(SVD_factors) == 0:
                                if KFAC_DEBUG:
                                    print(('approx %s act factor with rank-1 SVD factors' % (var.name)))
                                # find closest rank-1 approx to the feature map
                                S, U, V = tf.batch_svd(tf.reshape(
                                    fpropFactor, [-1, KH * KW, C]))
                                # get rank-1 approx slides
                                sqrtS1 = tf.expand_dims(tf.sqrt(S[:, 0, 0]), 1)
                                patches_k = U[:, :, 0] * sqrtS1  # B x KH*KW
                                full_factor_shape = fpropFactor.get_shape()
                                patches_k.set_shape(
                                    [full_factor_shape[0], KH * KW])
                                patches_c = V[:, :, 0] * sqrtS1  # B x C
                                patches_c.set_shape([full_factor_shape[0], C])
                                SVD_factors[C] = patches_c
                                SVD_factors[KH * KW] = patches_k
                            fpropFactor = SVD_factors[stats_var_dim]

                        else:
                            # poor mem usage implementation
                            patches = tf.extract_image_patches(
                                fpropFactor,
                                ksizes=[1, convkernel_size[0], convkernel_size[1], 1],
                                strides=strides, rates=[1, 1, 1, 1], padding=padding)

                            if self._approxT2:
                                if KFAC_DEBUG:
                                    print(('approxT2 act fisher for %s' % (var.name)))
                                # T^2 terms * 1/T^2, size: B x C
                                fpropFactor = tf.reduce_mean(patches, [1, 2])
                            else:
                                # size: (B x Oh x Ow) x C
                                fpropFactor = tf.reshape(
                                    patches, [-1, flatten_size]) / Oh / Ow
                    fpropFactor_size = int(fpropFactor.get_shape()[-1])
                    if stats_var_dim == (fpropFactor_size + 1) and not self._blockdiag_bias:
                        if opType == 'Conv2D' and not self._approxT2:
                            # correct padding for numerical stability (we
                            # divided out OhxOw from activations for T1 approx)
                            fpropFactor = tf.concat([fpropFactor, tf.ones(
                                [tf.shape(fpropFactor)[0], 1]) / Oh / Ow], 1)
                        else:
                            # use homogeneous coordinates
                            fpropFactor = tf.concat(
                                [fpropFactor, tf.ones([tf.shape(fpropFactor)[0], 1])], 1)

                    # average over the number of data points in a batch
                    # divided by B
                    cov = tf.matmul(fpropFactor, fpropFactor,
                                    transpose_a=True) / tf.cast(B, tf.float32)
                    updateOps.append(cov)
                    statsUpdates[stats_var] = cov
                    if opType != 'Conv2D':
                        # HACK: for convolution we recompute fprop stats for
                        # every layer including forking layers
                        statsUpdates_cache[stats_var] = cov

            for stats_var in bpropStats_vars:
                stats_var_dim = int(stats_var.get_shape()[0])
                if stats_var not in statsUpdates_cache:
                    old_bpropFactor = bpropFactor
                    bpropFactor_shape = bpropFactor.get_shape()
                    B = tf.shape(bpropFactor)[0]  # batch size
                    C = int(bpropFactor_shape[-1])  # num channels
                    if opType == 'Conv2D' or len(bpropFactor_shape) == 4:
                        if fpropFactor is not None:
                            if self._approxT2:
                                if KFAC_DEBUG:
                                    print(('approxT2 grad fisher for %s' % (var.name)))
                                bpropFactor = tf.reduce_sum(
                                    bpropFactor, [1, 2])  # T^2 terms * 1/T^2
                            else:
                                bpropFactor = tf.reshape(
                                    bpropFactor, [-1, C]) * Oh * Ow  # T * 1/T terms
                        else:
                            # just doing block diag approx. spatial independence
                            # structure does not apply here. summing over
                            # spatial locations
                            if KFAC_DEBUG:
                                print(('block diag approx fisher for %s' % (var.name)))
                            bpropFactor = tf.reduce_sum(bpropFactor, [1, 2])

                    # assume sampled loss is averaged. TO-DO: figure out better
                    # way to handle this
                    bpropFactor *= tf.to_float(B)
                    ##

                    cov_b = tf.matmul(
                        bpropFactor, bpropFactor, transpose_a=True) / tf.to_float(tf.shape(bpropFactor)[0])

                    updateOps.append(cov_b)
                    statsUpdates[stats_var] = cov_b
                    statsUpdates_cache[stats_var] = cov_b

        if KFAC_DEBUG:
            aKey = list(statsUpdates.keys())[0]
            statsUpdates[aKey] = tf.Print(statsUpdates[aKey],
                                          [tf.convert_to_tensor('step:'),
                                           self.global_step,
                                           tf.convert_to_tensor('computing stats'),
                                           ])
        self.statsUpdates = statsUpdates
        return statsUpdates
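
    # apply_stats wires the covariance updates into the graph: while
    # stats_step < stats_accum_iter the new covariances are accumulated with
    # weight 1/stats_accum_iter, afterwards they are folded into an
    # exponential running average (see _apply_stats). With async_stats=True
    # the updates are instead pushed through a single-slot FIFO queue.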

    def apply_stats(self, statsUpdates):
        """Apply the freshly computed stats to the accumulators / running averages."""

        def updateAccumStats():
            if self._full_stats_init:
                return tf.cond(tf.greater(self.sgd_step, self._cold_iter),
                               lambda: tf.group(*self._apply_stats(
                                   statsUpdates, accumulate=True,
                                   accumulateCoeff=1. / self._stats_accum_iter)),
                               tf.no_op)
            else:
                return tf.group(*self._apply_stats(
                    statsUpdates, accumulate=True,
                    accumulateCoeff=1. / self._stats_accum_iter))

        def updateRunningAvgStats(statsUpdates, fac_iter=1):
            # return tf.cond(tf.greater_equal(self.factor_step,
            # tf.convert_to_tensor(fac_iter)), lambda:
            # tf.group(*self._apply_stats(stats_list, varlist)), tf.no_op)
            return tf.group(*self._apply_stats(statsUpdates))

        if self._async_stats:
            # asynchronous stats update
            update_stats = self._apply_stats(statsUpdates)

            queue = tf.FIFOQueue(
                1, [item.dtype for item in update_stats],
                shapes=[item.get_shape() for item in update_stats])
            enqueue_op = queue.enqueue(update_stats)

            def dequeue_stats_op():
                return queue.dequeue()
            self.qr_stats = tf.train.QueueRunner(queue, [enqueue_op])
            update_stats_op = tf.cond(
                tf.equal(queue.size(), tf.convert_to_tensor(0)),
                tf.no_op, lambda: tf.group(*[dequeue_stats_op(), ]))
        else:
            # synchronous stats update
            update_stats_op = tf.cond(
                tf.greater_equal(self.stats_step, self._stats_accum_iter),
                lambda: updateRunningAvgStats(statsUpdates), updateAccumStats)
        self._update_stats_op = update_stats_op
        return update_stats_op
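
    # _apply_stats performs the actual variable updates. In accumulation mode:
    #   S <- S + accumulateCoeff * S_new
    # in running-average mode:
    #   S <- stats_decay * S + (1 - stats_decay) * S_new
    # and increments stats_step once all updates have run.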

    def _apply_stats(self, statsUpdates, accumulate=False, accumulateCoeff=0.):
        updateOps = []
        # obtain the stats var list
        for stats_var in statsUpdates:
            stats_new = statsUpdates[stats_var]
            if accumulate:
                # simple superbatch averaging
                update_op = tf.assign_add(
                    stats_var, accumulateCoeff * stats_new, use_locking=True)
            else:
                # exponential running averaging
                update_op = tf.assign(
                    stats_var, stats_var * self._stats_decay, use_locking=True)
                update_op = tf.assign_add(
                    update_op, (1. - self._stats_decay) * stats_new, use_locking=True)
            updateOps.append(update_op)

        with tf.control_dependencies(updateOps):
            stats_step_op = tf.assign_add(self.stats_step, 1)

        if KFAC_DEBUG:
            stats_step_op = (tf.Print(stats_step_op,
                                      [tf.convert_to_tensor('step:'),
                                       self.global_step,
                                       tf.convert_to_tensor('fac step:'),
                                       self.factor_step,
                                       tf.convert_to_tensor('sgd step:'),
                                       self.sgd_step,
                                       tf.convert_to_tensor('Accum:'),
                                       tf.convert_to_tensor(accumulate),
                                       tf.convert_to_tensor('Accum coeff:'),
                                       tf.convert_to_tensor(accumulateCoeff),
                                       tf.convert_to_tensor('stat step:'),
                                       self.stats_step, updateOps[0], updateOps[1]]))
        return [stats_step_op, ]
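
    # getStatsEigen lazily creates, on the CPU, one eigenvalue vector `e` and
    # one eigenvector matrix `Q` variable per covariance statistic; these hold
    # the most recent eigendecomposition used for preconditioning.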

    def getStatsEigen(self, stats=None):
        if len(self.stats_eigen) == 0:
            stats_eigen = {}
            if stats is None:
                stats = self.stats

            tmpEigenCache = {}
            with tf.device('/cpu:0'):
                for var in stats:
                    for key in ['fprop_concat_stats', 'bprop_concat_stats']:
                        for stats_var in stats[var][key]:
                            if stats_var not in tmpEigenCache:
                                stats_dim = stats_var.get_shape()[1].value
                                e = tf.Variable(
                                    tf.ones([stats_dim]),
                                    name='KFAC_FAC/' + stats_var.name.split(':')[0] + '/e',
                                    trainable=False)
                                Q = tf.Variable(
                                    tf.diag(tf.ones([stats_dim])),
                                    name='KFAC_FAC/' + stats_var.name.split(':')[0] + '/Q',
                                    trainable=False)
                                stats_eigen[stats_var] = {'e': e, 'Q': Q}
                                tmpEigenCache[stats_var] = stats_eigen[stats_var]
                            else:
                                stats_eigen[stats_var] = tmpEigenCache[stats_var]
            self.stats_eigen = stats_eigen
        return self.stats_eigen
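
    # computeStatsEigen builds the ops that eigendecompose every covariance
    # matrix (tf.self_adjoint_eig on the CPU, optionally in float64) and
    # returns the flat list of resulting e/Q tensors; eigen_reverse_lookup
    # maps each of them back to the variable it should be written into.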

    def computeStatsEigen(self):
        """ compute the eigen decomp using copied var stats to avoid concurrent read/write from other queue """
        # TO-DO: figure out why this op has delays (possibly moving
        # eigenvectors around?)
        with tf.device('/cpu:0'):
            def removeNone(tensor_list):
                local_list = []
                for item in tensor_list:
                    if item is not None:
                        local_list.append(item)
                return local_list

            def copyStats(var_list):
                print("copying stats to buffer tensors before eigen decomp")
                redundant_stats = {}
                copied_list = []
                for item in var_list:
                    if item is not None:
                        if item not in redundant_stats:
                            if self._use_float64:
                                redundant_stats[item] = tf.cast(
                                    tf.identity(item), tf.float64)
                            else:
                                redundant_stats[item] = tf.identity(item)
                        copied_list.append(redundant_stats[item])
                    else:
                        copied_list.append(None)
                return copied_list
            #stats = [copyStats(self.fStats), copyStats(self.bStats)]
            #stats = [self.fStats, self.bStats]

            stats_eigen = self.stats_eigen
            computedEigen = {}
            eigen_reverse_lookup = {}
            updateOps = []
            # sync copied stats
            # with tf.control_dependencies(removeNone(stats[0]) +
            # removeNone(stats[1])):
            with tf.control_dependencies([]):
                for stats_var in stats_eigen:
                    if stats_var not in computedEigen:
                        eigens = tf.self_adjoint_eig(stats_var)
                        e = eigens[0]
                        Q = eigens[1]
                        if self._use_float64:
                            e = tf.cast(e, tf.float32)
                            Q = tf.cast(Q, tf.float32)
                        updateOps.append(e)
                        updateOps.append(Q)
                        computedEigen[stats_var] = {'e': e, 'Q': Q}
                        eigen_reverse_lookup[e] = stats_eigen[stats_var]['e']
                        eigen_reverse_lookup[Q] = stats_eigen[stats_var]['Q']

                self.eigen_reverse_lookup = eigen_reverse_lookup
                self.eigen_update_list = updateOps

                if KFAC_DEBUG:
                    self.eigen_update_list = [item for item in updateOps]
                    with tf.control_dependencies(updateOps):
                        updateOps.append(tf.Print(tf.constant(
                            0.), [tf.convert_to_tensor('computed factor eigen')]))

        return updateOps
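
    # applyStatsEigen assigns freshly computed eigenvalues/eigenvectors back
    # into their variables and bumps factor_step, which gates the switch from
    # plain gradients to K-FAC preconditioned gradients.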

    def applyStatsEigen(self, eigen_list):
        updateOps = []
        print(('updating %d eigenvalue/vectors' % len(eigen_list)))
        for i, (tensor, mark) in enumerate(zip(eigen_list, self.eigen_update_list)):
            stats_eigen_var = self.eigen_reverse_lookup[mark]
            updateOps.append(
                tf.assign(stats_eigen_var, tensor, use_locking=True))

        with tf.control_dependencies(updateOps):
            factor_step_op = tf.assign_add(self.factor_step, 1)
            updateOps.append(factor_step_op)
            if KFAC_DEBUG:
                updateOps.append(tf.Print(tf.constant(
                    0.), [tf.convert_to_tensor('updated kfac factors')]))
        return updateOps
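
    # getKfacPrecondUpdates applies the Kronecker-factored preconditioner to
    # each gradient. With A = Q_a diag(e_a) Q_a^T and G = Q_g diag(e_g) Q_g^T,
    # the update for a weight matrix W is (up to damping):
    #
    #   precond(dW) = Q_a [ (Q_a^T dW Q_g) / (e_a e_g^T + damping) ] Q_g^T
    #
    # i.e. project into the eigenbasis, divide elementwise by the Kronecker
    # eigenvalue products (with either factored or Tikhonov-style damping),
    # and project back. Finally all updates are rescaled by
    # min(1, sqrt(clip_kl / vFv)), where vFv sums lr^2 * <precond(dW), dW>
    # over variables: an approximate trust-region / KL clip.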

    def getKfacPrecondUpdates(self, gradlist, varlist):
        updatelist = []
        vg = 0.

        assert len(self.stats) > 0
        assert len(self.stats_eigen) > 0
        assert len(self.factors) > 0
        counter = 0

        grad_dict = {var: grad for grad, var in zip(gradlist, varlist)}

        for grad, var in zip(gradlist, varlist):
            GRAD_RESHAPE = False
            GRAD_TRANSPOSE = False

            fpropFactoredFishers = self.stats[var]['fprop_concat_stats']
            bpropFactoredFishers = self.stats[var]['bprop_concat_stats']

            if (len(fpropFactoredFishers) + len(bpropFactoredFishers)) > 0:
                counter += 1
                GRAD_SHAPE = grad.get_shape()
                if len(grad.get_shape()) > 2:
                    # reshape conv kernel parameters
                    KW = int(grad.get_shape()[0])
                    KH = int(grad.get_shape()[1])
                    C = int(grad.get_shape()[2])
                    D = int(grad.get_shape()[3])

                    if len(fpropFactoredFishers) > 1 and self._channel_fac:
                        # reshape conv kernel parameters into tensor
                        grad = tf.reshape(grad, [KW * KH, C, D])
                    else:
                        # reshape conv kernel parameters into 2D grad
                        grad = tf.reshape(grad, [-1, D])
                    GRAD_RESHAPE = True
                elif len(grad.get_shape()) == 1:
                    # reshape bias or 1D parameters
                    D = int(grad.get_shape()[0])

                    grad = tf.expand_dims(grad, 0)
                    GRAD_RESHAPE = True
                else:
                    # 2D parameters
                    C = int(grad.get_shape()[0])
                    D = int(grad.get_shape()[1])

                if (self.stats[var]['assnBias'] is not None) and not self._blockdiag_bias:
                    # use homogeneous coordinates; only works for 2D grad.
                    # TO-DO: figure out how to factorize bias grad
                    # stack bias grad
                    var_assnBias = self.stats[var]['assnBias']
                    grad = tf.concat(
                        [grad, tf.expand_dims(grad_dict[var_assnBias], 0)], 0)

                # project gradient to eigen space and reshape the eigenvalues
                # for broadcasting
                eigVals = []

                for idx, stats in enumerate(self.stats[var]['fprop_concat_stats']):
                    Q = self.stats_eigen[stats]['Q']
                    e = detectMinVal(self.stats_eigen[stats]['e'], var,
                                     name='act', debug=KFAC_DEBUG)

                    Q, e = factorReshape(Q, e, grad, facIndx=idx, ftype='act')
                    eigVals.append(e)
                    grad = gmatmul(Q, grad, transpose_a=True, reduce_dim=idx)

                for idx, stats in enumerate(self.stats[var]['bprop_concat_stats']):
                    Q = self.stats_eigen[stats]['Q']
                    e = detectMinVal(self.stats_eigen[stats]['e'], var,
                                     name='grad', debug=KFAC_DEBUG)

                    Q, e = factorReshape(Q, e, grad, facIndx=idx, ftype='grad')
                    eigVals.append(e)
                    grad = gmatmul(grad, Q, transpose_b=False, reduce_dim=idx)
                ##

                #####
                # whiten using eigenvalues
                weightDecayCoeff = 0.
                if var in self._weight_decay_dict:
                    weightDecayCoeff = self._weight_decay_dict[var]
                    if KFAC_DEBUG:
                        print(('weight decay coeff for %s is %f' % (var.name, weightDecayCoeff)))

                if self._factored_damping:
                    if KFAC_DEBUG:
                        print(('use factored damping for %s' % (var.name)))
                    coeffs = 1.
                    num_factors = len(eigVals)
                    # compute the ratio of the two trace norms of the left and
                    # right KFac matrices, and their generalization
                    if len(eigVals) == 1:
                        damping = self._epsilon + weightDecayCoeff
                    else:
                        damping = tf.pow(
                            self._epsilon + weightDecayCoeff, 1. / num_factors)
                    eigVals_tnorm_avg = [tf.reduce_mean(
                        tf.abs(e)) for e in eigVals]
                    for e, e_tnorm in zip(eigVals, eigVals_tnorm_avg):
                        eig_tnorm_negList = [
                            item for item in eigVals_tnorm_avg if item != e_tnorm]
                        if len(eigVals) == 1:
                            adjustment = 1.
                        elif len(eigVals) == 2:
                            adjustment = tf.sqrt(
                                e_tnorm / eig_tnorm_negList[0])
                        else:
                            eig_tnorm_negList_prod = reduce(
                                lambda x, y: x * y, eig_tnorm_negList)
                            adjustment = tf.pow(
                                tf.pow(e_tnorm, num_factors - 1.) / eig_tnorm_negList_prod, 1. / num_factors)
                        coeffs *= (e + adjustment * damping)
                else:
                    coeffs = 1.
                    damping = (self._epsilon + weightDecayCoeff)
                    for e in eigVals:
                        coeffs *= e
                    coeffs += damping

                #grad = tf.Print(grad, [tf.convert_to_tensor('1'), tf.convert_to_tensor(var.name), grad.get_shape()])

                grad /= coeffs

                #grad = tf.Print(grad, [tf.convert_to_tensor('2'), tf.convert_to_tensor(var.name), grad.get_shape()])
                #####
                # project gradient back to euclidean space
                for idx, stats in enumerate(self.stats[var]['fprop_concat_stats']):
                    Q = self.stats_eigen[stats]['Q']
                    grad = gmatmul(Q, grad, transpose_a=False, reduce_dim=idx)

                for idx, stats in enumerate(self.stats[var]['bprop_concat_stats']):
                    Q = self.stats_eigen[stats]['Q']
                    grad = gmatmul(grad, Q, transpose_b=True, reduce_dim=idx)
                ##

                #grad = tf.Print(grad, [tf.convert_to_tensor('3'), tf.convert_to_tensor(var.name), grad.get_shape()])
                if (self.stats[var]['assnBias'] is not None) and not self._blockdiag_bias:
                    # use homogeneous coordinates; only works for 2D grad.
                    # TO-DO: figure out how to factorize bias grad
                    # un-stack bias grad
                    var_assnBias = self.stats[var]['assnBias']
                    C_plus_one = int(grad.get_shape()[0])
                    grad_assnBias = tf.reshape(
                        tf.slice(grad, begin=[C_plus_one - 1, 0], size=[1, -1]),
                        var_assnBias.get_shape())
                    grad_assnWeights = tf.slice(grad,
                                                begin=[0, 0],
                                                size=[C_plus_one - 1, -1])
                    grad_dict[var_assnBias] = grad_assnBias
                    grad = grad_assnWeights

                #grad = tf.Print(grad, [tf.convert_to_tensor('4'), tf.convert_to_tensor(var.name), grad.get_shape()])
                if GRAD_RESHAPE:
                    grad = tf.reshape(grad, GRAD_SHAPE)

                grad_dict[var] = grad

        print(('projecting %d gradient matrices' % counter))

        for g, var in zip(gradlist, varlist):
            grad = grad_dict[var]
            ### clipping ###
            if KFAC_DEBUG:
                print(('apply clipping to %s' % (var.name)))
                tf.Print(grad, [tf.sqrt(tf.reduce_sum(tf.pow(grad, 2)))], "Euclidean norm of new grad")
            local_vg = tf.reduce_sum(grad * g * (self._lr * self._lr))
            vg += local_vg

        # rescale everything
        if KFAC_DEBUG:
            print('apply vFv clipping')

        scaling = tf.minimum(1., tf.sqrt(self._clip_kl / vg))
        if KFAC_DEBUG:
            scaling = tf.Print(scaling, [tf.convert_to_tensor(
                'clip: '), scaling, tf.convert_to_tensor(' vFv: '), vg])
        with tf.control_dependencies([tf.assign(self.vFv, vg)]):
            updatelist = [grad_dict[var] for var in varlist]
            for i, item in enumerate(updatelist):
                updatelist[i] = scaling * item

        return updatelist

    def compute_gradients(self, loss, var_list=None):
        varlist = var_list
        if varlist is None:
            varlist = tf.trainable_variables()
        g = tf.gradients(loss, varlist)

        return [(a, b) for a, b in zip(g, varlist)]
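
    # apply_gradients_kfac chains the whole update: bump global_step, run the
    # stats update, every `kfac_update` steps recompute and apply the
    # eigendecompositions (synchronously, or via a background queue when
    # async_=True), precondition the gradients once factor_step > 0, and
    # finally take a MomentumOptimizer step with learning rate
    # lr * (1 - momentum). Returns (update_op, queue_runner_or_None).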

    def apply_gradients_kfac(self, grads):
        g, varlist = list(zip(*grads))

        if len(self.stats_eigen) == 0:
            self.getStatsEigen()

        qr = None
        # launch eigen-decomp on a queue thread
        if self._async:
            print('Use async eigen decomp')
            # get a list of factor loading tensors
            factorOps_dummy = self.computeStatsEigen()

            # define a queue for the list of factor loading tensors
            queue = tf.FIFOQueue(
                1, [item.dtype for item in factorOps_dummy],
                shapes=[item.get_shape() for item in factorOps_dummy])
            enqueue_op = tf.cond(
                tf.logical_and(tf.equal(tf.mod(self.stats_step, self._kfac_update),
                                        tf.convert_to_tensor(0)),
                               tf.greater_equal(self.stats_step, self._stats_accum_iter)),
                lambda: queue.enqueue(self.computeStatsEigen()), tf.no_op)

            def dequeue_op():
                return queue.dequeue()

            qr = tf.train.QueueRunner(queue, [enqueue_op])

        updateOps = []
        global_step_op = tf.assign_add(self.global_step, 1)
        updateOps.append(global_step_op)

        with tf.control_dependencies([global_step_op]):

            # compute updates
            assert self._update_stats_op is not None
            updateOps.append(self._update_stats_op)
            dependency_list = []
            if not self._async:
                dependency_list.append(self._update_stats_op)

            with tf.control_dependencies(dependency_list):
                def no_op_wrapper():
                    return tf.group(*[tf.assign_add(self.cold_step, 1)])

                if not self._async:
                    # synchronous eigen-decomp updates
                    updateFactorOps = tf.cond(
                        tf.logical_and(tf.equal(tf.mod(self.stats_step, self._kfac_update),
                                                tf.convert_to_tensor(0)),
                                       tf.greater_equal(self.stats_step, self._stats_accum_iter)),
                        lambda: tf.group(*self.applyStatsEigen(self.computeStatsEigen())),
                        no_op_wrapper)
                else:
                    # asynchronous eigen-decomp updates using queue
                    updateFactorOps = tf.cond(
                        tf.greater_equal(self.stats_step, self._stats_accum_iter),
                        lambda: tf.cond(tf.equal(queue.size(), tf.convert_to_tensor(0)),
                                        tf.no_op,
                                        lambda: tf.group(*self.applyStatsEigen(dequeue_op()))),
                        no_op_wrapper)

                updateOps.append(updateFactorOps)

                with tf.control_dependencies([updateFactorOps]):
                    def gradOp():
                        return list(g)

                    def getKfacGradOp():
                        return self.getKfacPrecondUpdates(g, varlist)
                    u = tf.cond(tf.greater(self.factor_step, tf.convert_to_tensor(0)),
                                getKfacGradOp, gradOp)

                    optim = tf.train.MomentumOptimizer(
                        self._lr * (1. - self._momentum), self._momentum)
                    #optim = tf.train.AdamOptimizer(self._lr, epsilon=0.01)

                    def optimOp():
                        def updateOptimOp():
                            if self._full_stats_init:
                                return tf.cond(tf.greater(self.factor_step, tf.convert_to_tensor(0)),
                                               lambda: optim.apply_gradients(list(zip(u, varlist))),
                                               tf.no_op)
                            else:
                                return optim.apply_gradients(list(zip(u, varlist)))
                        if self._full_stats_init:
                            return tf.cond(tf.greater_equal(self.stats_step, self._stats_accum_iter),
                                           updateOptimOp, tf.no_op)
                        else:
                            return tf.cond(tf.greater_equal(self.sgd_step, self._cold_iter),
                                           updateOptimOp, tf.no_op)
                    updateOps.append(optimOp())

        return tf.group(*updateOps), qr
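
    # apply_gradients runs plain (globally clipped) momentum SGD with the cold
    # learning rate for the first `cold_iter` steps, then switches to the full
    # K-FAC update built by apply_gradients_kfac.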

    def apply_gradients(self, grads):
        coldOptim = tf.train.MomentumOptimizer(
            self._cold_lr, self._momentum)

        def coldSGDstart():
            sgd_grads, sgd_var = zip(*grads)

            if self.max_grad_norm is not None:
                sgd_grads, sgd_grad_norm = tf.clip_by_global_norm(
                    sgd_grads, self.max_grad_norm)

            sgd_grads = list(zip(sgd_grads, sgd_var))

            sgd_step_op = tf.assign_add(self.sgd_step, 1)
            coldOptim_op = coldOptim.apply_gradients(sgd_grads)
            if KFAC_DEBUG:
                with tf.control_dependencies([sgd_step_op, coldOptim_op]):
                    sgd_step_op = tf.Print(
                        sgd_step_op, [self.sgd_step, tf.convert_to_tensor('doing cold sgd step')])
            return tf.group(*[sgd_step_op, coldOptim_op])

        kfacOptim_op, qr = self.apply_gradients_kfac(grads)

        def warmKFACstart():
            return kfacOptim_op

        return tf.cond(tf.greater(self.sgd_step, self._cold_iter), warmKFACstart, coldSGDstart), qr
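
    # minimize ties everything together: build gradients, build and register
    # the stats updates, and return (train_op, q_runner). When constructed
    # with async_=True the caller is expected to start `q_runner` in the
    # session so eigendecompositions happen on a background thread.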

    def minimize(self, loss, loss_sampled, var_list=None):
        grads = self.compute_gradients(loss, var_list=var_list)
        # compute_and_apply_stats also sets self._update_stats_op as a side
        # effect, which apply_gradients_kfac depends on
        update_stats_op = self.compute_and_apply_stats(
            loss_sampled, var_list=var_list)
        return self.apply_gradients(grads)