import torch
import torch.nn.functional as F
from torch.optim.optimizer import Optimizer


class KFAC(Optimizer):
def __init__(self, net, eps, sua=False, pi=False, update_freq=1,
alpha=1.0, constraint_norm=False):
""" K-FAC Preconditionner for Linear and Conv2d layers.
Computes the K-FAC of the second moment of the gradients.
It works for Linear and Conv2d layers and silently skip other layers.
Args:
net (torch.nn.Module): Network to precondition.
eps (float): Tikhonov regularization parameter for the inverses.
            sua (bool): Applies the SUA (spatially uncorrelated
                activations) approximation.
pi (bool): Computes pi correction for Tikhonov regularization.
            update_freq (int): Recompute the inverses every update_freq
                updates.
            alpha (float): Running average parameter (if 1.0, no running
                average).
            constraint_norm (bool): Scale the gradients so that the
                preconditioned update has unit Fisher norm (under the
                K-FAC approximation).
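
        Example:
            A minimal usage sketch; ``model``, ``criterion``, ``inputs`` and
            ``targets`` are placeholders for the user's own objects::

                preconditioner = KFAC(model, eps=0.1)
                optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
                loss = criterion(model(inputs), targets)
                optimizer.zero_grad()
                loss.backward()
                preconditioner.step()  # preconditions the gradients in place
                optimizer.step()       # applies the preconditioned gradients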
"""
self.eps = eps
self.sua = sua
self.pi = pi
self.update_freq = update_freq
self.alpha = alpha
self.constraint_norm = constraint_norm
self.params = []
self._fwd_handles = []
self._bwd_handles = []
self._iteration_counter = 0
for mod in net.modules():
mod_class = mod.__class__.__name__
if mod_class in ['Linear', 'Conv2d']:
handle = mod.register_forward_pre_hook(self._save_input)
self._fwd_handles.append(handle)
handle = mod.register_full_backward_hook(self._save_grad_output)
self._bwd_handles.append(handle)
params = [mod.weight]
if mod.bias is not None:
params.append(mod.bias)
d = {'params': params, 'mod': mod, 'layer_type': mod_class}
self.params.append(d)
super(KFAC, self).__init__(self.params, {})
    def step(self, update_stats=True, update_params=True):
"""Performs one step of preconditioning."""
fisher_norm = 0.
for group in self.param_groups:
# Getting parameters
if len(group['params']) == 2:
weight, bias = group['params']
else:
weight = group['params'][0]
bias = None
state = self.state[weight]
            # Update covariances and inverses
if update_stats:
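                # Amortize the cost of the inversions: the inverse factors
                # are only recomputed every update_freq iterations.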
if self._iteration_counter % self.update_freq == 0:
self._compute_covs(group, state)
ixxt, iggt = self._inv_covs(state['xxt'], state['ggt'],
state['num_locations'])
state['ixxt'] = ixxt
state['iggt'] = iggt
else:
if self.alpha != 1:
self._compute_covs(group, state)
if update_params:
                # Preconditioning
gw, gb = self._precond(weight, bias, group, state)
# Updating gradients
if self.constraint_norm:
fisher_norm += (weight.grad * gw).sum()
weight.grad.data = gw
if bias is not None:
if self.constraint_norm:
fisher_norm += (bias.grad * gb).sum()
bias.grad.data = gb
# Cleaning
if 'x' in self.state[group['mod']]:
del self.state[group['mod']]['x']
if 'gy' in self.state[group['mod']]:
del self.state[group['mod']]['gy']
        # Optionally scale the gradients to unit Fisher norm
if update_params and self.constraint_norm:
scale = (1. / fisher_norm) ** 0.5
for group in self.param_groups:
for param in group['params']:
param.grad.data *= scale
if update_stats:
self._iteration_counter += 1
def _save_input(self, mod, i):
"""Saves input of layer to compute covariance."""
if mod.training:
self.state[mod]['x'] = i[0]
def _save_grad_output(self, mod, grad_input, grad_output):
"""Saves grad on output of layer to compute covariance."""
if mod.training:
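            # Scale by the batch size: a mean-reduced loss divides the
            # per-sample gradients by N, which this multiplication undoes.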
self.state[mod]['gy'] = grad_output[0] * grad_output[0].size(0)
def _precond(self, weight, bias, group, state):
"""Applies preconditioning."""
if group['layer_type'] == 'Conv2d' and self.sua:
return self._precond_sua(weight, bias, group, state)
ixxt = state['ixxt']
iggt = state['iggt']
g = weight.grad.data
s = g.shape
if group['layer_type'] == 'Conv2d':
g = g.contiguous().view(s[0], s[1]*s[2]*s[3])
if bias is not None:
gb = bias.grad.data
g = torch.cat([g, gb.view(gb.shape[0], 1)], dim=1)
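        # Apply the Kronecker-factored inverse: multiplying by iggt on the
        # left and ixxt on the right approximates F^{-1} acting on vec(g).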
g = torch.mm(torch.mm(iggt, g), ixxt)
if group['layer_type'] == 'Conv2d':
g /= state['num_locations']
if bias is not None:
gb = g[:, -1].contiguous().view(*bias.shape)
g = g[:, :-1]
else:
gb = None
g = g.contiguous().view(*s)
return g, gb
def _precond_sua(self, weight, bias, group, state):
"""Preconditioning for KFAC SUA."""
ixxt = state['ixxt']
iggt = state['iggt']
g = weight.grad.data
s = g.shape
g = g.permute(1, 0, 2, 3).contiguous()
if bias is not None:
gb = bias.grad.view(1, -1, 1, 1).expand(1, -1, s[2], s[3])
g = torch.cat([g, gb], dim=0)
g = torch.mm(ixxt, g.contiguous().view(-1, s[0]*s[2]*s[3]))
g = g.view(-1, s[0], s[2], s[3]).permute(1, 0, 2, 3).contiguous()
g = torch.mm(iggt, g.view(s[0], -1)).view(s[0], -1, s[2], s[3])
g /= state['num_locations']
if bias is not None:
gb = g[:, -1, s[2]//2, s[3]//2]
g = g[:, :-1]
else:
gb = None
return g, gb
def _compute_covs(self, group, state):
"""Computes the covariances."""
mod = group['mod']
        x = self.state[mod]['x']
        gy = self.state[mod]['gy']
# Computation of xxt
if group['layer_type'] == 'Conv2d':
if not self.sua:
x = F.unfold(x, mod.kernel_size, padding=mod.padding,
stride=mod.stride)
else:
x = x.view(x.shape[0], x.shape[1], -1)
x = x.data.permute(1, 0, 2).contiguous().view(x.shape[1], -1)
else:
x = x.data.t()
if mod.bias is not None:
ones = torch.ones_like(x[:1])
x = torch.cat([x, ones], dim=0)
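        # The first pass initializes xxt; later passes keep a running
        # average: xxt <- (1 - alpha) * xxt + (alpha / N) * x @ x.T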
if self._iteration_counter == 0:
state['xxt'] = torch.mm(x, x.t()) / float(x.shape[1])
else:
state['xxt'].addmm_(mat1=x, mat2=x.t(),
beta=(1. - self.alpha),
alpha=self.alpha / float(x.shape[1]))
# Computation of ggt
if group['layer_type'] == 'Conv2d':
gy = gy.data.permute(1, 0, 2, 3)
state['num_locations'] = gy.shape[2] * gy.shape[3]
gy = gy.contiguous().view(gy.shape[0], -1)
else:
gy = gy.data.t()
state['num_locations'] = 1
if self._iteration_counter == 0:
state['ggt'] = torch.mm(gy, gy.t()) / float(gy.shape[1])
else:
state['ggt'].addmm_(mat1=gy, mat2=gy.t(),
beta=(1. - self.alpha),
alpha=self.alpha / float(gy.shape[1]))
    def _inv_covs(self, xxt, ggt, num_locations):
        """Inverts the covariance matrices."""
# Computes pi
pi = 1.0
if self.pi:
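            # pi balances the Tikhonov damping between the two factors,
            # as in Martens & Grosse (2015):
            # pi = (tr(xxt) / dim(xxt)) / (tr(ggt) / dim(ggt))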
tx = torch.trace(xxt) * ggt.shape[0]
tg = torch.trace(ggt) * xxt.shape[0]
pi = (tx / tg)
        # Regularize and invert
eps = self.eps / num_locations
diag_xxt = xxt.new(xxt.shape[0]).fill_((eps * pi) ** 0.5)
diag_ggt = ggt.new(ggt.shape[0]).fill_((eps / pi) ** 0.5)
ixxt = (xxt + torch.diag(diag_xxt)).inverse()
iggt = (ggt + torch.diag(diag_ggt)).inverse()
return ixxt, iggt
def __del__(self):
for handle in self._fwd_handles + self._bwd_handles:
handle.remove()
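

if __name__ == "__main__":
    # Minimal smoke-test sketch of the intended usage. The tiny network,
    # the random data, and the hyperparameters below are illustrative
    # placeholders only: K-FAC preconditions the gradients in place, and a
    # plain SGD step then applies them.
    import torch.nn as nn

    model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                          nn.Flatten(), nn.Linear(8 * 8 * 8, 10))
    preconditioner = KFAC(model, eps=0.1, update_freq=10)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    inputs = torch.randn(4, 3, 8, 8)
    targets = torch.randint(0, 10, (4,))
    loss = F.cross_entropy(model(inputs), targets)

    optimizer.zero_grad()
    loss.backward()
    preconditioner.step()  # rewrites the .grad buffers in place
    optimizer.step()       # applies the preconditioned gradients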