Parameter space matrix representations

All parameter space matrix representations inherit from nngeometry.object.pspace.PMatAbstract. This abstract class defines all methods that can be used with every representation (with some exceptions!). nngeometry.object.pspace.PMatAbstract cannot be instantiated; instead, you have to choose one of the concrete representations below.

class nngeometry.object.pspace.PMatAbstract(generator, data=None, examples=None)

A \(d \times d\) matrix in parameter space. This abstract class defines common methods used in concrete representations.

Parameters
  • generator (nngeometry.generator.jacobian.Jacobian) – The generator

  • data – the matrix data. If None, examples must be provided, and the generator is used to populate the matrix data

  • examples – if data is None, the examples used to populate the matrix with the generator. examples is either a DataLoader, or a single mini-batch of (inputs, targets) from a DataLoader

Note

Exactly one of data and examples must be provided: they cannot both be None, and they cannot both be set.
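To give an idea of what populating the matrix from examples can mean for a Fisher-type matrix, the NumPy sketch below averages \(J_i^\top J_i\) over examples, where \(J_i\) is the Jacobian of the outputs with respect to the \(d\) parameters for example \(i\). It assumes the per-example Jacobians are already available as an array; the helper dense_from_jacobians is hypothetical and not part of NNGeometry, and the exact quantity a generator computes depends on the metric being estimated.

```python
import numpy as np

def dense_from_jacobians(jacobians):
    """Average of J_i^T J_i over examples (hypothetical helper).

    jacobians: array of shape (n_examples, n_outputs, d)
    returns: array of shape (d, d)
    """
    n, _, d = jacobians.shape
    M = np.zeros((d, d))
    for J in jacobians:
        M += J.T @ J  # contribution of a single example
    return M / n

rng = np.random.default_rng(0)
jac = rng.standard_normal((8, 3, 5))  # 8 examples, 3 outputs, d = 5 parameters
M = dense_from_jacobians(jac)
print(M.shape)  # (5, 5)
```

The result is symmetric and positive semi-definite by construction, which is the setting the representations below assume.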

abstract get_diag()

Computes and returns the diagonal elements of this matrix.

Returns

a PyTorch Tensor

size(dim=None)

Size of the matrix as a tuple, regardless of the actual size in memory.

Parameters

dim (int or None) – if None, the full size is returned as a tuple; otherwise, only the size along dimension dim is returned

>>> M.size()
(1254, 1254)
>>> M.size(0)
1254

abstract solve(v, regul)

Solves the linear system \(Mx = v\) for \(x\), using Tikhonov regularization.

Parameters
  • regul (float) – Tikhonov regularization coefficient

  • v – right-hand side vector \(v\)
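A minimal NumPy sketch of a Tikhonov-regularized solve, assuming the regularization amounts to adding regul times the identity before solving, i.e. \((M + \lambda I)x = v\) with \(\lambda\) = regul. Concrete representations may apply the regularization differently (for instance per Kronecker factor), so this only illustrates the idea; regularized_solve is a hypothetical name.

```python
import numpy as np

def regularized_solve(M, v, regul=1e-8):
    """Solve (M + regul * I) x = v for x."""
    d = M.shape[0]
    return np.linalg.solve(M + regul * np.eye(d), v)

rng = np.random.default_rng(0)
A = rng.standard_normal((6, 4))
M = A @ A.T  # 6 x 6 PSD matrix of rank 4: an unregularized solve would be ill-posed
v = rng.standard_normal(6)
x = regularized_solve(M, v, regul=1e-3)
```

The regularization makes the system well-posed even when \(M\) is singular, which is common for Fisher matrices of over-parameterized networks.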

abstract vTMv(v)

Computes the quadratic form defined by M in v, namely the product \(v^\top M v\)

Parameters

v (object.vector.PVector) – vector \(v\)
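As a plain NumPy illustration of the quantity computed here (not NNGeometry's implementation), the quadratic form can be evaluated as a dot product between \(v\) and \(Mv\); quadratic_form is a hypothetical name.

```python
import numpy as np

def quadratic_form(M, v):
    """Quadratic form v^T M v, computed as the dot product of v with M v."""
    return float(v @ (M @ v))

M = np.array([[2.0, 1.0], [1.0, 3.0]])
v = np.array([1.0, 2.0])
print(quadratic_form(M, v))  # 18.0
```

Here \(Mv = (4, 7)\), so \(v^\top M v = 1 \cdot 4 + 2 \cdot 7 = 18\).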

Concrete representations

NNGeometry makes it easy to switch between representations. Each representation comes with a tradeoff between accuracy and memory/computational cost. When testing a new algorithm, we recommend first experimenting on a small network with the most accurate representation that fits in memory (typically nngeometry.object.pspace.PMatDense), then moving to larger-scale experiments with a representation that uses less memory.
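To give an order of magnitude for this tradeoff, the following back-of-the-envelope sketch counts stored floats for a single fully connected layer with n_in inputs, n_out outputs and a bias. It assumes that a dense matrix stores all \(d^2\) entries, a diagonal one stores \(d\) entries, and KFAC stores two Kronecker factors of sizes \((n_{in}+1)^2\) and \(n_{out}^2\); these are estimates, not measurements of NNGeometry's actual memory use.

```python
def storage_floats(n_in, n_out):
    """Approximate number of floats stored for one linear layer with bias."""
    d = (n_in + 1) * n_out  # number of parameters: weights + bias
    return {
        "dense": d * d,                             # full d x d matrix
        "kfac": (n_in + 1) ** 2 + n_out ** 2,       # two Kronecker factors
        "diag": d,                                  # one entry per parameter
    }

print(storage_floats(784, 100))
# {'dense': 6162250000, 'kfac': 626225, 'diag': 78500}
```

For this 784 to 100 layer, the dense matrix would need about 24.6 GB in float32, whereas the KFAC factors fit in a few megabytes.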

class nngeometry.object.pspace.PMatBlockDiag(generator, data=None, examples=None)
get_diag()

Computes and returns the diagonal elements of this matrix.

Returns

a PyTorch Tensor

mm(other)

Matrix-matrix product where other is another instance of PMatBlockDiag

Parameters

other (nngeometry.object.PMatBlockDiag) – Other FIM matrix

Returns

The matrix-matrix product

Return type

nngeometry.object.PMatBlockDiag

solve(vs, regul=1e-08)

Solves the linear system \(Mx = v\) for \(x\), using Tikhonov regularization.

Parameters
  • regul (float) – Tikhonov regularization coefficient

  • vs – right-hand side vector \(v\)
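For a block-diagonal matrix, the linear system decouples into one independent system per block. The NumPy sketch below assumes the blocks are stored as a list of dense arrays and that the right-hand side is split accordingly, with regul added to the diagonal of each block; this layout is an assumption for illustration, not PMatBlockDiag's internal format, and blockdiag_solve is a hypothetical name.

```python
import numpy as np

def blockdiag_solve(blocks, v, regul=1e-8):
    """Solve (B + regul * I) x = v where B = diag(blocks[0], blocks[1], ...)."""
    xs = []
    start = 0
    for blk in blocks:
        n = blk.shape[0]
        v_blk = v[start:start + n]  # slice of v belonging to this block
        xs.append(np.linalg.solve(blk + regul * np.eye(n), v_blk))
        start += n
    return np.concatenate(xs)

rng = np.random.default_rng(0)
blocks = []
for n in (3, 2):
    A = rng.standard_normal((n, n))
    blocks.append(A @ A.T + np.eye(n))  # symmetric positive definite block
v = rng.standard_normal(5)
x = blockdiag_solve(blocks, v)
```

The cost is that of solving each small block separately, instead of one system of the full dimension.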

vTMv(vector)

Computes the quadratic form defined by M in v, namely the product \(v^\top M v\)

Parameters

vector (object.vector.PVector) – vector \(v\)

class nngeometry.object.pspace.PMatDense(generator, data=None, examples=None)
get_diag()

Computes and returns the diagonal elements of this matrix.

Returns

a PyTorch Tensor

mm(other)

Matrix-matrix product where other is another instance of PMatDense

Parameters

other (nngeometry.object.PMatDense) – Other FIM matrix

Returns

The matrix-matrix product

Return type

nngeometry.object.PMatDense

solve(v, regul=1e-08, impl='solve')

Solves \(Mx = v\) for \(x\).

vTMv(v)

Computes the quadratic form defined by M in v, namely the product \(v^\top M v\)

Parameters

v (object.vector.PVector) – vector \(v\)

class nngeometry.object.pspace.PMatDiag(generator, data=None, examples=None)
get_diag()

Computes and returns the diagonal elements of this matrix.

Returns

a PyTorch Tensor

mm(other)

Matrix-matrix product where other is another instance of PMatDiag

Parameters

other (nngeometry.object.PMatDiag) – Other FIM matrix

Returns

The matrix-matrix product

Return type

nngeometry.object.PMatDiag

solve(v, regul=1e-08)

Solves \(Mx = v\) for \(x\).
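Since a diagonal matrix only needs its \(d\) diagonal entries, solving a system and multiplying two diagonal matrices both reduce to elementwise operations costing \(O(d)\). The NumPy sketch below assumes the diagonal is stored as a 1-D array and that regul is added to every diagonal entry; diag_solve and diag_mm are hypothetical names.

```python
import numpy as np

def diag_solve(diag, v, regul=1e-8):
    """Solve (D + regul * I) x = v with D = diag(diag): an elementwise division."""
    return v / (diag + regul)

def diag_mm(diag_a, diag_b):
    """Product of two diagonal matrices: the diagonal of elementwise products."""
    return diag_a * diag_b

d = np.array([1.0, 2.0, 4.0])
v = np.array([1.0, 1.0, 1.0])
x = diag_solve(d, v, regul=0.0)
print(x.tolist())  # [1.0, 0.5, 0.25]
```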

vTMv(v)

Computes the quadratic form defined by M in v, namely the product \(v^\top M v\)

Parameters

v (object.vector.PVector) – vector \(v\)

class nngeometry.object.pspace.PMatEKFAC(generator, data=None, examples=None)

EKFAC representation from George, Laurent et al., Fast Approximate Natural Gradient Descent in a Kronecker-factored Eigenbasis, NIPS 2018

get_KFE(split_weight_bias=True)

Returns a dict, indexed by layers, of dense eigenvectors constructed from the Kronecker-factored eigenvectors.

  • split_weight_bias (bool): if True, the parameters are ordered in the same way as in the dense or blockdiag representation, but this involves more operations. Otherwise, the coefficients corresponding to the bias are interleaved with the coefficients of the weight matrix.
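The Kronecker-factored eigenbasis relies on the fact that if \(A = U_A \Lambda_A U_A^\top\) and \(G = U_G \Lambda_G U_G^\top\), then \(U_A \otimes U_G\) diagonalizes \(A \otimes G\), with eigenvalues \(\Lambda_A \otimes \Lambda_G\). The NumPy sketch below checks this numerically for one layer. The names A and G (covariances of layer inputs and of gradients with respect to layer outputs) follow the usual KFAC convention and are assumptions, not NNGeometry attribute names.

```python
import numpy as np

rng = np.random.default_rng(0)

def random_spd(n):
    """Random symmetric positive definite matrix."""
    X = rng.standard_normal((n, n))
    return X @ X.T + n * np.eye(n)

A = random_spd(3)  # e.g. covariance of the layer inputs
G = random_spd(2)  # e.g. covariance of gradients w.r.t. the layer outputs

eig_A, U_A = np.linalg.eigh(A)
eig_G, U_G = np.linalg.eigh(G)
U = np.kron(U_A, U_G)  # dense eigenvectors built from the Kronecker factors

# In this basis, kron(A, G) becomes diagonal
D = U.T @ np.kron(A, G) @ U
off_diagonal = D - np.diag(np.diag(D))
print(np.abs(off_diagonal).max() < 1e-8)  # True
```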

get_dense_tensor(split_weight_bias=True)

Returns the matrix as a dense PyTorch Tensor.

  • split_weight_bias (bool): if True, the parameters are ordered in the same way as in the dense or blockdiag representation, but this involves more operations. Otherwise, the coefficients corresponding to the bias are interleaved with the coefficients of the weight matrix.

get_diag(v)

Computes and returns the diagonal elements of this matrix.

Returns

a PyTorch Tensor

solve(vs, regul=1e-08)

Solves the linear system \(Mx = v\) for \(x\), using Tikhonov regularization.

Parameters
  • regul (float) – Tikhonov regularization coefficient

  • vs – right-hand side vector \(v\)

update_diag(examples)

Updates the diagonal in the KFE (i.e. the approximate eigenvalues) using the current values of the model’s parameters.
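In EKFAC, the eigenbasis is kept fixed while the diagonal is re-estimated as the second moment of per-example gradients projected onto that eigenbasis. The NumPy sketch below illustrates this estimator, assuming per-example gradients are given as rows of an array and the eigenbasis as an orthogonal matrix U; it is not NNGeometry's implementation, and update_kfe_diag is a hypothetical name.

```python
import numpy as np

def update_kfe_diag(U, per_example_grads):
    """Re-estimate the diagonal in a fixed eigenbasis U.

    U: (d, d) orthogonal matrix whose columns are eigenvectors
    per_example_grads: (n, d) array, one gradient per row
    returns: (d,) array, mean of squared projected gradients
    """
    projected = per_example_grads @ U  # coordinates of each gradient in the basis
    return (projected ** 2).mean(axis=0)

rng = np.random.default_rng(0)
grads = rng.standard_normal((50, 4))
M = grads.T @ grads / 50  # empirical second-moment matrix
_, U = np.linalg.eigh(M)
s = update_kfe_diag(U, grads)
```

When U is the exact eigenbasis of the second-moment matrix, as in this toy check, the estimator recovers its eigenvalues; with a Kronecker-factored eigenbasis, it gives the corrected diagonal used by EKFAC.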

vTMv(vector)

Computes the quadratic form defined by M in v, namely the product \(v^\top M v\)

Parameters

vector (object.vector.PVector) – vector \(v\)

class nngeometry.object.pspace.PMatImplicit(generator, data=None, examples=None)

PMatImplicit is a very special representation: no element of the matrix is ever computed. Instead, various linear algebra operations are performed implicitly using efficient tricks.

The computations are exact, meaning that no approximation is involved. This is useful for networks whose matrix is too big to fit in memory.
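The idea behind an implicit representation can be illustrated for a Fisher-type matrix \(M = \frac{1}{n}\sum_i J_i^\top J_i\): products such as \(Mv\) and \(v^\top M v = \frac{1}{n}\sum_i \|J_i v\|^2\) only require Jacobian-vector products, so \(M\) itself is never formed. In the NumPy sketch below, the per-example Jacobians are explicit arrays so that the result can be checked; in practice they would be replaced by Jacobian-vector products from automatic differentiation. The function names are hypothetical.

```python
import numpy as np

def implicit_mv(jacobians, v):
    """M v with M = mean_i J_i^T J_i, without building the d x d matrix."""
    Jv = jacobians @ v  # shape (n, n_outputs): one J_i v per example
    return np.einsum("nod,no->d", jacobians, Jv) / len(jacobians)

def implicit_vTMv(jacobians, v):
    """v^T M v computed as mean_i ||J_i v||^2."""
    Jv = jacobians @ v
    return float((Jv ** 2).sum() / len(jacobians))

rng = np.random.default_rng(0)
jac = rng.standard_normal((8, 3, 5))  # 8 examples, 3 outputs, d = 5 parameters
v = rng.standard_normal(5)
print(implicit_vTMv(jac, v) >= 0)  # True
```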

get_diag()

Computes and returns the diagonal elements of this matrix.

Returns

a PyTorch Tensor

solve(v)

Solves the linear system \(Mx = v\) for \(x\).

Parameters
  • v – right-hand side vector \(v\)

vTMv(v)

Computes the quadratic form defined by M in v, namely the product \(v^\top M v\)

Parameters

v (object.vector.PVector) – vector \(v\)

class nngeometry.object.pspace.PMatKFAC(generator, data=None, examples=None)
get_dense_tensor(split_weight_bias=True)

Returns the matrix as a dense PyTorch Tensor.

  • split_weight_bias (bool): if True, the parameters are ordered in the same way as in the dense or blockdiag representation, but this involves more operations. Otherwise, the coefficients corresponding to the bias are interleaved with the coefficients of the weight matrix.

get_diag(split_weight_bias=True)

Computes and returns the diagonal elements of this matrix.

  • split_weight_bias (bool): if True, the parameters are ordered in the same way as in the dense or blockdiag representation, but this involves more operations. Otherwise, the coefficients corresponding to the bias are interleaved with the coefficients of the weight matrix.

mm(other)

Matrix-matrix product where other is another instance of PMatKFAC

Parameters

other (nngeometry.object.PMatKFAC) – Other FIM matrix

Returns

The matrix-matrix product

Return type

nngeometry.object.PMatKFAC

solve(vs, regul=1e-08, use_pi=True)

Solves the linear system \(Mx = v\) for \(x\), using Tikhonov regularization.

Parameters
  • regul (float) – Tikhonov regularization coefficient

  • vs – right-hand side vector \(v\)
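For a single layer, a KFAC block has the form \(A \otimes G\), and its inverse can be applied without forming the Kronecker product, using \((A \otimes G)^{-1} = A^{-1} \otimes G^{-1}\) and the identity \((A \otimes G)\,\mathrm{vec}(V) = \mathrm{vec}(G V A^\top)\) (column-major vec). The NumPy sketch below applies this. It damps each factor separately, with the \(\pi\) scaling of Martens and Grosse (2015) when use_pi is True; whether NNGeometry's use_pi does exactly this is an assumption here. Note that damping the factors is not equivalent to adding regul times the identity to the full matrix. kfac_solve is a hypothetical name.

```python
import numpy as np

def kfac_solve(A, G, V, regul=1e-8, use_pi=True):
    """Solve (A_d kron G_d) vec(X) = vec(V) for X, with damped factors.

    A: (a, a) and G: (g, g) symmetric factors; V: (g, a) matrix.
    Damping: A_d = A + pi * sqrt(regul) * I, G_d = G + sqrt(regul) / pi * I.
    """
    a, g = A.shape[0], G.shape[0]
    pi = np.sqrt((np.trace(A) / a) / (np.trace(G) / g)) if use_pi else 1.0
    A_d = A + pi * np.sqrt(regul) * np.eye(a)
    G_d = G + np.sqrt(regul) / pi * np.eye(g)
    # X = G_d^{-1} V A_d^{-1}: two small solves instead of inverting A_d kron G_d
    left = np.linalg.solve(G_d, V)  # G_d^{-1} V
    return np.linalg.solve(A_d, left.T).T  # left @ A_d^{-1}, since A_d is symmetric

rng = np.random.default_rng(0)
XA = rng.standard_normal((3, 3))
A = XA @ XA.T + np.eye(3)
XG = rng.standard_normal((2, 2))
G = XG @ XG.T + np.eye(2)
V = rng.standard_normal((2, 3))
X = kfac_solve(A, G, V, regul=1e-2)
```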

vTMv(vector)

Computes the quadratic form defined by M in v, namely the product \(v^\top M v\)

Parameters

vector (object.vector.PVector) – vector \(v\)

class nngeometry.object.pspace.PMatLowRank(generator, data=None, examples=None)
get_diag()

Computes and returns the diagonal elements of this matrix.

Returns

a PyTorch Tensor

solve(b, regul=1e-08)

Solves the linear system \(Mx = b\) for \(x\), using Tikhonov regularization.

Parameters
  • regul (float) – Tikhonov regularization coefficient

  • b – right-hand side vector
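If a low-rank matrix is stored through a factor \(U \in \mathbb{R}^{d \times r}\) with \(M = U U^\top\) and \(r \ll d\), a regularized system can be solved with the Woodbury identity, which only requires an \(r \times r\) solve: \((\lambda I + U U^\top)^{-1} b = \frac{1}{\lambda}\left(b - U(\lambda I_r + U^\top U)^{-1} U^\top b\right)\). The factorization \(M = UU^\top\) is an assumption made for this NumPy sketch; PMatLowRank's internal storage may differ, and lowrank_solve is a hypothetical name.

```python
import numpy as np

def lowrank_solve(U, b, regul=1e-8):
    """Solve (regul * I + U U^T) x = b via the Woodbury identity.

    U: (d, r) factor with r << d; b: (d,) right-hand side.
    Only an r x r system is solved, never a d x d one.
    """
    r = U.shape[1]
    small = regul * np.eye(r) + U.T @ U  # r x r matrix
    return (b - U @ np.linalg.solve(small, U.T @ b)) / regul

rng = np.random.default_rng(0)
U = rng.standard_normal((100, 3))
b = rng.standard_normal(100)
x = lowrank_solve(U, b, regul=1e-1)
```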

vTMv(v)

Computes the quadratic form defined by M in v, namely the product \(v^\top M v\)

Parameters

v (object.vector.PVector) – vector \(v\)

class nngeometry.object.pspace.PMatQuasiDiag(generator, data=None, examples=None)

Quasi-diagonal approximation as described in Ollivier, Riemannian metrics for neural networks I: feedforward networks, Information and Inference: A Journal of the IMA, 2015

get_diag()

Computes and returns the diagonal elements of this matrix.

Returns

a PyTorch Tensor

solve(vs, regul=1e-08)

Solves the linear system \(Mx = v\) for \(x\), using Tikhonov regularization.

Parameters
  • regul (float) – Tikhonov regularization coefficient

  • vs – right-hand side vector \(v\)

vTMv(vs)

Computes the quadratic form defined by M in v, namely the product \(v^\top M v\)

Parameters

vs (object.vector.PVector) – vector \(v\)

nngeometry.object.pspace.bdot(A, B)

Batched dot product of A and B.
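A minimal NumPy sketch of a batched dot product, assuming A and B are arrays of shape (batch, n) and the result is the vector of row-wise dot products; the shapes actually accepted by nngeometry.object.pspace.bdot are not documented here, so this is an assumption, and batched_dot is a hypothetical name.

```python
import numpy as np

def batched_dot(A, B):
    """Row-wise dot products: out[i] = sum_k A[i, k] * B[i, k]."""
    return np.einsum("ik,ik->i", A, B)

A = np.array([[1.0, 2.0], [3.0, 4.0]])
B = np.array([[5.0, 6.0], [7.0, 8.0]])
print(batched_dot(A, B).tolist())  # [17.0, 53.0]
```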