Source code for plugins.natgrad.kfac

import torch
import torch.nn.functional as F

from torch.optim.optimizer import Optimizer


class KFAC(Optimizer):

    def __init__(self, net, eps, sua=False, pi=False, update_freq=1,
                 alpha=1.0, constraint_norm=False):
        """K-FAC preconditioner for Linear and Conv2d layers.

        Computes the K-FAC approximation of the second moment of the
        gradients. It works for Linear and Conv2d layers and silently
        skips other layers.

        Args:
            net (torch.nn.Module): Network to precondition.
            eps (float): Tikhonov regularization parameter for the inverses.
            sua (bool): Applies the SUA approximation.
            pi (bool): Computes the pi correction for Tikhonov regularization.
            update_freq (int): Recompute the inverses every update_freq updates.
            alpha (float): Running average parameter (if == 1, no running average).
            constraint_norm (bool): Scale the gradients so the preconditioned
                update has unit Fisher norm.
        """
        self.eps = eps
        self.sua = sua
        self.pi = pi
        self.update_freq = update_freq
        self.alpha = alpha
        self.constraint_norm = constraint_norm
        self.params = []
        self._fwd_handles = []
        self._bwd_handles = []
        self._iteration_counter = 0
        for mod in net.modules():
            mod_class = mod.__class__.__name__
            if mod_class in ['Linear', 'Conv2d']:
                # Hooks cache the layer input and the gradient w.r.t. its output
                handle = mod.register_forward_pre_hook(self._save_input)
                self._fwd_handles.append(handle)
                handle = mod.register_full_backward_hook(self._save_grad_output)
                self._bwd_handles.append(handle)
                params = [mod.weight]
                if mod.bias is not None:
                    params.append(mod.bias)
                d = {'params': params, 'mod': mod, 'layer_type': mod_class}
                self.params.append(d)
        super(KFAC, self).__init__(self.params, {})
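
    # Sketch of the approximation implemented below: for a layer with input
    # activations a and pre-activation gradients g, K-FAC approximates the
    # Fisher block as the Kronecker product A ⊗ G, where A = E[a a^T]
    # (stored as state['xxt']) and G = E[g g^T] (stored as state['ggt']).
    # The preconditioned gradient is then G^{-1} (dL/dW) A^{-1}, computed in
    # _precond() from the cached inverses 'ixxt' and 'iggt'.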

    def step(self, update_stats=True, update_params=True):
        """Performs one step of preconditioning."""
        fisher_norm = 0.
        for group in self.param_groups:
            # Getting parameters
            if len(group['params']) == 2:
                weight, bias = group['params']
            else:
                weight = group['params'][0]
                bias = None
            state = self.state[weight]
            # Update covariances and inverses
            if update_stats:
                if self._iteration_counter % self.update_freq == 0:
                    self._compute_covs(group, state)
                    ixxt, iggt = self._inv_covs(state['xxt'], state['ggt'],
                                                state['num_locations'])
                    state['ixxt'] = ixxt
                    state['iggt'] = iggt
                else:
                    if self.alpha != 1:
                        self._compute_covs(group, state)
            if update_params:
                # Preconditioning
                gw, gb = self._precond(weight, bias, group, state)
                # Updating gradients
                if self.constraint_norm:
                    fisher_norm += (weight.grad * gw).sum()
                weight.grad.data = gw
                if bias is not None:
                    if self.constraint_norm:
                        fisher_norm += (bias.grad * gb).sum()
                    bias.grad.data = gb
            # Cleaning
            if 'x' in self.state[group['mod']]:
                del self.state[group['mod']]['x']
            if 'gy' in self.state[group['mod']]:
                del self.state[group['mod']]['gy']
        # Optionally scale the norm of the gradients
        if update_params and self.constraint_norm:
            scale = (1. / fisher_norm) ** 0.5
            for group in self.param_groups:
                for param in group['params']:
                    param.grad.data *= scale
        if update_stats:
            self._iteration_counter += 1
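
    # Note on constraint_norm: step() accumulates
    # fisher_norm = sum_l <grad_l, F_l^{-1} grad_l> and then multiplies every
    # gradient by 1 / sqrt(fisher_norm), so the resulting preconditioned
    # update has (approximately) unit norm in the K-FAC Fisher metric.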

    def _save_input(self, mod, i):
        """Saves input of layer to compute covariance."""
        if mod.training:
            self.state[mod]['x'] = i[0]

    def _save_grad_output(self, mod, grad_input, grad_output):
        """Saves grad on output of layer to compute covariance."""
        if mod.training:
            self.state[mod]['gy'] = grad_output[0] * grad_output[0].size(0)

    def _precond(self, weight, bias, group, state):
        """Applies preconditioning."""
        if group['layer_type'] == 'Conv2d' and self.sua:
            return self._precond_sua(weight, bias, group, state)
        ixxt = state['ixxt']
        iggt = state['iggt']
        g = weight.grad.data
        s = g.shape
        if group['layer_type'] == 'Conv2d':
            g = g.contiguous().view(s[0], s[1] * s[2] * s[3])
        if bias is not None:
            gb = bias.grad.data
            g = torch.cat([g, gb.view(gb.shape[0], 1)], dim=1)
        g = torch.mm(torch.mm(iggt, g), ixxt)
        if group['layer_type'] == 'Conv2d':
            g /= state['num_locations']
        if bias is not None:
            gb = g[:, -1].contiguous().view(*bias.shape)
            g = g[:, :-1]
        else:
            gb = None
        g = g.contiguous().view(*s)
        return g, gb

    def _precond_sua(self, weight, bias, group, state):
        """Preconditioning for K-FAC with the SUA approximation."""
        ixxt = state['ixxt']
        iggt = state['iggt']
        g = weight.grad.data
        s = g.shape
        mod = group['mod']
        g = g.permute(1, 0, 2, 3).contiguous()
        if bias is not None:
            gb = bias.grad.view(1, -1, 1, 1).expand(1, -1, s[2], s[3])
            g = torch.cat([g, gb], dim=0)
        g = torch.mm(ixxt, g.contiguous().view(-1, s[0] * s[2] * s[3]))
        g = g.view(-1, s[0], s[2], s[3]).permute(1, 0, 2, 3).contiguous()
        g = torch.mm(iggt, g.view(s[0], -1)).view(s[0], -1, s[2], s[3])
        g /= state['num_locations']
        if bias is not None:
            gb = g[:, -1, s[2] // 2, s[3] // 2]
            g = g[:, :-1]
        else:
            gb = None
        return g, gb

    def _compute_covs(self, group, state):
        """Computes the covariances."""
        mod = group['mod']
        x = self.state[group['mod']]['x']
        gy = self.state[group['mod']]['gy']
        # Computation of xxt
        if group['layer_type'] == 'Conv2d':
            if not self.sua:
                x = F.unfold(x, mod.kernel_size, padding=mod.padding,
                             stride=mod.stride)
            else:
                x = x.view(x.shape[0], x.shape[1], -1)
            x = x.data.permute(1, 0, 2).contiguous().view(x.shape[1], -1)
        else:
            x = x.data.t()
        if mod.bias is not None:
            ones = torch.ones_like(x[:1])
            x = torch.cat([x, ones], dim=0)
        if self._iteration_counter == 0:
            state['xxt'] = torch.mm(x, x.t()) / float(x.shape[1])
        else:
            state['xxt'].addmm_(mat1=x, mat2=x.t(),
                                beta=(1. - self.alpha),
                                alpha=self.alpha / float(x.shape[1]))
        # Computation of ggt
        if group['layer_type'] == 'Conv2d':
            gy = gy.data.permute(1, 0, 2, 3)
            state['num_locations'] = gy.shape[2] * gy.shape[3]
            gy = gy.contiguous().view(gy.shape[0], -1)
        else:
            gy = gy.data.t()
            state['num_locations'] = 1
        if self._iteration_counter == 0:
            state['ggt'] = torch.mm(gy, gy.t()) / float(gy.shape[1])
        else:
            state['ggt'].addmm_(mat1=gy, mat2=gy.t(),
                                beta=(1. - self.alpha),
                                alpha=self.alpha / float(gy.shape[1]))

    def _inv_covs(self, xxt, ggt, num_locations):
        """Inverts the covariances."""
        # Compute the pi correction
        pi = 1.0
        if self.pi:
            tx = torch.trace(xxt) * ggt.shape[0]
            tg = torch.trace(ggt) * xxt.shape[0]
            pi = (tx / tg)
        # Regularize and invert
        eps = self.eps / num_locations
        diag_xxt = xxt.new(xxt.shape[0]).fill_((eps * pi) ** 0.5)
        diag_ggt = ggt.new(ggt.shape[0]).fill_((eps / pi) ** 0.5)
        ixxt = (xxt + torch.diag(diag_xxt)).inverse()
        iggt = (ggt + torch.diag(diag_ggt)).inverse()
        return ixxt, iggt

    def __del__(self):
        for handle in self._fwd_handles + self._bwd_handles:
            handle.remove()
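
# Minimal usage sketch (hypothetical names: 'model', 'loader', 'x', 'y' are
# placeholders, not part of this module). The preconditioner is stepped after
# backward() and before the regular optimizer step, since KFAC.step() rewrites
# the .grad buffers in place:
#
#     import torch
#     from plugins.natgrad.kfac import KFAC
#
#     model = torch.nn.Sequential(torch.nn.Linear(784, 128), torch.nn.ReLU(),
#                                 torch.nn.Linear(128, 10))
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
#     preconditioner = KFAC(model, eps=0.1, update_freq=10)
#
#     for x, y in loader:
#         optimizer.zero_grad()
#         loss = torch.nn.functional.cross_entropy(model(x), y)
#         loss.backward()
#         preconditioner.step()   # replaces .grad with the K-FAC direction
#         optimizer.step()        # applies the preconditioned gradients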