Parameter space matrix representations
All parameter space matrix representations inherit from nngeometry.object.pspace.PMatAbstract. This abstract class defines the methods that can be used with all representations (with some exceptions). nngeometry.object.pspace.PMatAbstract cannot be instantiated; instead, choose one of the concrete representations below.
- class nngeometry.object.pspace.PMatAbstract(generator, data=None, examples=None)
A \(d \times d\) matrix in parameter space. This abstract class defines common methods used in concrete representations.
- Parameters
generator (nngeometry.generator.jacobian.Jacobian) – The generator
data – if None, examples must be different from None, and the generator is used to populate the matrix data
examples – if data is None, these examples are used to populate the matrix using the generator. examples is either a DataLoader, or a single mini-batch of (inputs, targets) from a DataLoader
Note
Exactly one of data and examples must be different from None: they cannot both be None, and they cannot both be set.
- abstract get_diag()
Computes and returns the diagonal elements of this matrix.
- Returns
a PyTorch Tensor
- size(dim=None)
Size of the matrix as a tuple, regardless of the actual size in memory.
- Parameters
dim (int or None) – if not None, only the size along this dimension is returned, as an int
>>> M.size()
(1254, 1254)
>>> M.size(0)
1254
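- abstract solve(v, regul)
Solves \(Mx = v\) in \(x\), using Tikhonov regularization.
- Parameters
v (nngeometry.object.vector.PVector) – right-hand side vector \(v\)
regul (float) – Tikhonov regularization coefficient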
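As a rough illustration of what a Tikhonov-regularized solve computes, here is a standalone NumPy sketch on a dense array. It assumes the regularization is applied by adding regul times the identity to M, which is the usual meaning of Tikhonov regularization. tikhonov_solve is a hypothetical helper, not part of NNGeometry, and some representations (e.g. KFAC) may damp their factors differently.

```python
import numpy as np

def tikhonov_solve(M, v, regul=1e-8):
    """Solve (M + regul * I) x = v for a dense symmetric PSD matrix M.

    Hypothetical helper working on plain NumPy arrays, not on PMat*/PVector objects.
    """
    d = M.shape[0]
    return np.linalg.solve(M + regul * np.eye(d), v)

# A small PSD matrix built as J^T J / n, in the style of an empirical Fisher matrix
rng = np.random.default_rng(0)
J = rng.standard_normal((20, 5))
M = J.T @ J / J.shape[0]
v = rng.standard_normal(5)
x = tikhonov_solve(M, v, regul=1e-3)
print(np.allclose((M + 1e-3 * np.eye(5)) @ x, v))  # True
```

Larger values of regul make the system better conditioned, at the cost of a solution that is further from the exact \(M^{-1} v\).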
- abstract vTMv(v)
Computes the quadratic form defined by M in v, namely the product \(v^\top M v\).
- Parameters
v (nngeometry.object.vector.PVector) – vector \(v\)
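The quadratic form only needs one matrix-vector product followed by a dot product. The NumPy sketch below shows the arithmetic on a 2 x 2 example. Here vTMv is a hypothetical standalone function that operates on a plain array, not the NNGeometry method.

```python
import numpy as np

def vTMv(M, v):
    # v^T M v: one matrix-vector product, then a dot product (no outer product formed)
    return float(v @ (M @ v))

M = np.array([[2.0, 1.0], [1.0, 3.0]])
v = np.array([1.0, 2.0])
print(vTMv(M, v))  # M v = [4, 7], so v^T M v = 4 + 14 = 18.0
```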
Concrete representations
NNGeometry makes it easy to switch between representations. Each representation comes with a trade-off between accuracy and memory/computational cost. When testing a new algorithm, we recommend first running it on a small network with the most accurate representation that fits in memory (typically nngeometry.object.pspace.PMatDense), then moving to a larger-scale experiment with a lower-memory representation.
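To make the memory side of this trade-off concrete, the sketch below counts the floats each representation would store for a hypothetical 784-100-10 MLP with biases. The KFAC count assumes that each layer stores one (in+1) x (in+1) input factor (bias folded into the input side) and one out x out output factor, as in the standard K-FAC formulation. The exact storage used by NNGeometry may differ.

```python
# Hypothetical MLP: (in_features, out_features) for each Linear layer, with biases
layers = [(784, 100), (100, 10)]

def n_params(fan_in, fan_out):
    return fan_in * fan_out + fan_out  # weight matrix + bias vector

sizes = [n_params(i, o) for i, o in layers]
d = sum(sizes)  # total number of parameters

storage = {
    "dense": d * d,                          # full d x d matrix
    "blockdiag": sum(s * s for s in sizes),  # one dense block per layer
    # KFAC: (in+1) x (in+1) input factor plus out x out output factor per layer
    "kfac": sum((i + 1) ** 2 + o ** 2 for i, o in layers),
    "diag": d,                               # d diagonal coefficients
}
for name, n in storage.items():
    print(f"{name:>9}: {n:,} floats")
```

For this network, the dense matrix needs 6,321,840,100 floats (about 25 GB in float32). The block-diagonal one is barely smaller (6,163,270,100) because the first layer dominates. KFAC (636,526) and the diagonal (79,510) fit easily.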
- class nngeometry.object.pspace.PMatBlockDiag(generator, data=None, examples=None)
- get_diag()
Computes and returns the diagonal elements of this matrix.
- Returns
a PyTorch Tensor
- mm(other)
Matrix-matrix product where other is another instance of PMatBlockDiag
- Parameters
other (nngeometry.object.PMatBlockDiag) – Other FIM matrix
- Returns
The matrix-matrix product
- Return type
nngeometry.object.PMatBlockDiag
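- solve(vs, regul=1e-08)
Solves \(Mx = v\) in \(x\), using Tikhonov regularization.
- Parameters
vs (nngeometry.object.vector.PVector) – right-hand side vector \(v\)
regul (float) – Tikhonov regularization coefficient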
- vTMv(vector)
Computes the quadratic form defined by M in v, namely the product \(v^\top M v\).
- Parameters
vector (nngeometry.object.vector.PVector) – vector \(v\)
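A block-diagonal matrix has no coupling between layers, so both solve and mm reduce to independent per-block operations. The NumPy sketch below illustrates this with a plain list of dense blocks and checks the result against the equivalent dense matrix. blockdiag_solve and blockdiag_mm are hypothetical helpers, and the regularization is assumed to be regul times the identity.

```python
import numpy as np

def blockdiag_solve(blocks, v, regul=1e-8):
    """Solve (M + regul I) x = v where M = diag(blocks[0], blocks[1], ...)."""
    out, start = [], 0
    for B in blocks:
        n = B.shape[0]
        # each block only sees its own slice of v: layers are solved independently
        out.append(np.linalg.solve(B + regul * np.eye(n), v[start:start + n]))
        start += n
    return np.concatenate(out)

def blockdiag_mm(blocks_a, blocks_b):
    # the product of two block-diagonal matrices with matching blocks is block-diagonal
    return [A @ B for A, B in zip(blocks_a, blocks_b)]

# Two PSD blocks of sizes 3 and 2, standing in for two layers
rng = np.random.default_rng(0)
blocks = []
for n in (3, 2):
    J = rng.standard_normal((10, n))
    blocks.append(J.T @ J / 10)
dense = np.zeros((5, 5))
dense[:3, :3], dense[3:, 3:] = blocks

v = rng.standard_normal(5)
x = blockdiag_solve(blocks, v, regul=1e-3)
print(np.allclose((dense + 1e-3 * np.eye(5)) @ x, v))  # True
```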
- class nngeometry.object.pspace.PMatDense(generator, data=None, examples=None)
- get_diag()
Computes and returns the diagonal elements of this matrix.
- Returns
a PyTorch Tensor
- mm(other)
Matrix-matrix product where other is another instance of PMatDense
- Parameters
other (nngeometry.object.PMatDense) – Other FIM matrix
- Returns
The matrix-matrix product
- Return type
nngeometry.object.PMatDense
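- solve(v, regul=1e-08, impl='solve')
Solves \(Mx = v\) in \(x\), using Tikhonov regularization.
- Parameters
v (nngeometry.object.vector.PVector) – right-hand side vector \(v\)
regul (float) – Tikhonov regularization coefficient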
- vTMv(v)
Computes the quadratic form defined by M in v, namely the product \(v^\top M v\).
- Parameters
v (nngeometry.object.vector.PVector) – vector \(v\)
- class nngeometry.object.pspace.PMatDiag(generator, data=None, examples=None)
- get_diag()
Computes and returns the diagonal elements of this matrix.
- Returns
a PyTorch Tensor
- mm(other)
Matrix-matrix product where other is another instance of PMatDiag
- Parameters
other (nngeometry.object.PMatDiag) – Other FIM matrix
- Returns
The matrix-matrix product
- Return type
nngeometry.object.PMatDiag
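- solve(v, regul=1e-08)
Solves \(Mx = v\) in \(x\), using Tikhonov regularization.
- Parameters
v (nngeometry.object.vector.PVector) – right-hand side vector \(v\)
regul (float) – Tikhonov regularization coefficient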
- vTMv(v)
Computes the quadratic form defined by M in v, namely the product \(v^\top M v\).
- Parameters
v (nngeometry.object.vector.PVector) – vector \(v\)
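For a diagonal matrix, solve and vTMv become elementwise operations, which is why this representation is so cheap. The NumPy sketch below uses hypothetical helpers diag_solve and diag_vTMv. It assumes regul is added to every diagonal entry.

```python
import numpy as np

def diag_solve(diag, v, regul=1e-8):
    # (D + regul I) is diagonal, so solving is an elementwise division
    return v / (diag + regul)

def diag_vTMv(diag, v):
    # v^T D v = sum_i D_ii v_i^2
    return float(np.sum(diag * v ** 2))

diag = np.array([1.0, 4.0, 0.5])
v = np.array([2.0, 4.0, 1.0])
print(diag_solve(diag, v, regul=0.0))  # [2. 1. 2.]
print(diag_vTMv(diag, v))              # 4 + 64 + 0.5 = 68.5
```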
- class nngeometry.object.pspace.PMatEKFAC(generator, data=None, examples=None)
EKFAC representation from George, Laurent et al., Fast Approximate Natural Gradient Descent in a Kronecker-factored Eigenbasis, NIPS 2018
- get_KFE(split_weight_bias=True)
Returns a dict, indexed by layer, of dense eigenvectors constructed from the Kronecker-factored eigenvectors.
- Parameters
split_weight_bias (bool) – if True, the parameters are ordered in the same way as in the dense or block-diagonal representation, but this involves more operations. Otherwise, the coefficients corresponding to the bias are interleaved with the coefficients of the weight matrix.
- get_dense_tensor(split_weight_bias=True)
Returns this matrix as a dense PyTorch Tensor.
- Parameters
split_weight_bias (bool) – if True, the parameters are ordered in the same way as in the dense or block-diagonal representation, but this involves more operations. Otherwise, the coefficients corresponding to the bias are interleaved with the coefficients of the weight matrix.
- get_diag(v)
Computes and returns the diagonal elements of this matrix.
- Returns
a PyTorch Tensor
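- solve(vs, regul=1e-08)
Solves \(Mx = v\) in \(x\), using Tikhonov regularization.
- Parameters
vs (nngeometry.object.vector.PVector) – right-hand side vector \(v\)
regul (float) – Tikhonov regularization coefficient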
- update_diag(examples)
Updates the diagonal in the KFE (i.e. the approximate eigenvalues) using the current values of the model's parameters.
- vTMv(vector)
Computes the quadratic form defined by M in v, namely the product \(v^\top M v\).
- Parameters
vector (nngeometry.object.vector.PVector) – vector \(v\)
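The NumPy sketch below illustrates the idea behind get_KFE, update_diag and solve for a single linear layer without bias, following the EKFAC construction. The Kronecker-factored eigenbasis (KFE) is built from the eigenvectors of the KFAC factors A and G. The diagonal is re-estimated as the second moment of per-example gradients expressed in that basis. Solving then amounts to rotating into the KFE, dividing, and rotating back. All names are hypothetical, regul is assumed to be added to the KFE diagonal, and the dense checks are only feasible for tiny layers.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_in, d_out = 50, 4, 3
a = rng.standard_normal((n, d_in))   # layer inputs, one row per example (no bias)
g = rng.standard_normal((n, d_out))  # gradients w.r.t. the layer outputs

# KFAC factors A = E[a a^T], G = E[g g^T] and their eigenvectors (the KFE)
A = a.T @ a / n
G = g.T @ g / n
_, U_A = np.linalg.eigh(A)
_, U_G = np.linalg.eigh(G)

# Per-example weight gradients g_i a_i^T, shape (n, d_out, d_in)
grads = np.einsum("no,ni->noi", g, a)

# update_diag analogue: E[(U_G^T grad_i U_A)^2] elementwise, shape (d_out, d_in)
proj = np.einsum("op,noi,iq->npq", U_G, grads, U_A)
diag_kfe = (proj ** 2).mean(axis=0)

def ekfac_solve(V, regul=1e-8):
    """Solve (M_ekfac + regul I) vec(X) = vec(V) for V of shape (d_out, d_in)."""
    V_kfe = U_G.T @ V @ U_A             # rotate into the KFE
    V_kfe = V_kfe / (diag_kfe + regul)  # divide by the regularized eigenvalues
    return U_G @ V_kfe @ U_A.T          # rotate back

# Dense checks with column-major vec(): vec(U_G X U_A^T) = kron(U_A, U_G) vec(X)
K = np.kron(U_A, U_G)
vecs = grads.transpose(0, 2, 1).reshape(n, -1)  # row i is vec(grad_i)
M_exact = vecs.T @ vecs / n
# the EKFAC diagonal is the diagonal of the exact block expressed in the KFE
print(np.allclose(diag_kfe.flatten(order="F"), np.diag(K.T @ M_exact @ K)))  # True

V = rng.standard_normal((d_out, d_in))
x = ekfac_solve(V, regul=1e-3)
M_reg = K @ np.diag(diag_kfe.flatten(order="F")) @ K.T + 1e-3 * np.eye(d_in * d_out)
print(np.allclose(M_reg @ x.flatten(order="F"), V.flatten(order="F")))  # True
```

The solve never forms the d x d block. It only uses the two small eigenvector matrices and a diagonal of the same shape as the weight matrix.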
- class nngeometry.object.pspace.PMatImplicit(generator, data=None, examples=None)
PMatImplicit is a very special representation: no element of the matrix is ever computed. Instead, various linear algebra operations are performed implicitly using efficient tricks.
The computations are exact, meaning that no approximation is involved. This is useful for networks whose parameter space matrix would be too big to fit in memory.
- get_diag()
Computes and returns the diagonal elements of this matrix.
- Returns
a PyTorch Tensor
- vTMv(v)
Computes the quadratic form defined by M in v, namely the product \(v^\top M v\).
- Parameters
v (nngeometry.object.vector.PVector) – vector \(v\)
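The NumPy sketch below shows how an operation can be exact without ever forming the matrix. It uses a linear model f(x) = W x and assumes that M is the average of \(J_i^\top J_i\) over examples (a Fisher/Gauss-Newton style matrix). The quadratic form then only needs the Jacobian-vector products \(J_i v\), which for this model are simply V x_i. implicit_vTMv is hypothetical. A real implementation would obtain these products by differentiating through the network rather than from a closed form.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_in, d_out = 100, 5, 3
X = rng.standard_normal((n, d_in))  # inputs of f(x) = W x, with W of shape (d_out, d_in)

def implicit_vTMv(V):
    """v^T M v with M = mean_i J_i^T J_i, computed without forming M.

    For f(x) = W x, the Jacobian-vector product in direction V is J_i vec(V) = V x_i.
    """
    JV = X @ V.T  # row i is V x_i, shape (n, d_out)
    return float((JV ** 2).sum() / n)

# Dense check on this small example: J_i = kron(I, x_i^T) for row-major vec(W)
J = np.stack([np.kron(np.eye(d_out), x[None, :]) for x in X])  # (n, d_out, d_out * d_in)
M = np.einsum("nop,noq->pq", J, J) / n
V = rng.standard_normal((d_out, d_in))
v = V.flatten()  # row-major vec, matching the Jacobians above
print(np.isclose(implicit_vTMv(V), v @ M @ v))  # True
```

The implicit version costs one forward-mode pass per example and uses memory proportional to the outputs, while the dense check needs a (d_out * d_in) x (d_out * d_in) matrix.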
- class nngeometry.object.pspace.PMatKFAC(generator, data=None, examples=None)
- get_dense_tensor(split_weight_bias=True)
Returns this matrix as a dense PyTorch Tensor.
- Parameters
split_weight_bias (bool) – if True, the parameters are ordered in the same way as in the dense or block-diagonal representation, but this involves more operations. Otherwise, the coefficients corresponding to the bias are interleaved with the coefficients of the weight matrix.
- get_diag(split_weight_bias=True)
Computes and returns the diagonal elements of this matrix.
- Parameters
split_weight_bias (bool) – if True, the parameters are ordered in the same way as in the dense or block-diagonal representation, but this involves more operations. Otherwise, the coefficients corresponding to the bias are interleaved with the coefficients of the weight matrix.
- mm(other)
Matrix-matrix product where other is another instance of PMatKFAC
- Parameters
other (nngeometry.object.PMatKFAC) – Other FIM matrix
- Returns
The matrix-matrix product
- Return type
nngeometry.object.PMatKFAC
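- solve(vs, regul=1e-08, use_pi=True)
Solves \(Mx = v\) in \(x\), using Tikhonov regularization.
- Parameters
vs (nngeometry.object.vector.PVector) – right-hand side vector \(v\)
regul (float) – Tikhonov regularization coefficient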
- vTMv(vector)
Computes the quadratic form defined by M in v, namely the product \(v^\top M v\).
- Parameters
vector (nngeometry.object.vector.PVector) – vector \(v\)
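The NumPy sketch below shows a KFAC solve for a single layer. It assumes M is approximated by kron(A, G), with A computed from layer inputs and G from output gradients. It also assumes that use_pi switches on the factored Tikhonov damping of Martens and Grosse (2015, "Optimizing Neural Networks with Kronecker-factored Approximate Curvature"). In that scheme, sqrt(regul) is split between the two factors using \(\pi = \sqrt{(\mathrm{tr}(A)/d_{in}) / (\mathrm{tr}(G)/d_{out})}\). This is a reading of the signature, not NNGeometry's code, and kfac_solve is hypothetical.

```python
import numpy as np

def kfac_solve(A, G, V, regul=1e-8, use_pi=True):
    """Solve kron(A_reg, G_reg) vec(X) = vec(V) for one layer, with column-major vec.

    A: (d_in, d_in) input factor, G: (d_out, d_out) output-gradient factor,
    V: (d_out, d_in) right-hand side in weight-matrix shape.
    """
    d_in, d_out = A.shape[0], G.shape[0]
    # pi balances the damping between the two factors (assumed meaning of use_pi)
    pi = np.sqrt((np.trace(A) / d_in) / (np.trace(G) / d_out)) if use_pi else 1.0
    A_reg = A + pi * np.sqrt(regul) * np.eye(d_in)
    G_reg = G + np.sqrt(regul) / pi * np.eye(d_out)
    # kron(A_reg, G_reg)^{-1} vec(V) = vec(G_reg^{-1} V A_reg^{-1}) since A_reg is symmetric
    return np.linalg.solve(G_reg, np.linalg.solve(A_reg, V.T).T)

rng = np.random.default_rng(0)
a = rng.standard_normal((40, 4))
g = 0.1 * rng.standard_normal((40, 3))  # gradients on a smaller scale than inputs
A, G = a.T @ a / 40, g.T @ g / 40
V = rng.standard_normal((3, 4))
X = kfac_solve(A, G, V, regul=1e-2, use_pi=False)
K = np.kron(A + 0.1 * np.eye(4), G + 0.1 * np.eye(3))  # sqrt(1e-2) = 0.1
print(np.allclose(K @ X.flatten(order="F"), V.flatten(order="F")))  # True
```

With this factored damping, the regularized matrix is \(A \otimes G + \pi\sqrt{\text{regul}}\, I \otimes G + (\sqrt{\text{regul}}/\pi)\, A \otimes I + \text{regul}\, I\). It is therefore not exactly \(M + \text{regul}\, I\). The extra cross terms are the price for keeping the solve Kronecker-factored.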
- class nngeometry.object.pspace.PMatLowRank(generator, data=None, examples=None)
- get_diag()
Computes and returns the diagonal elements of this matrix.
- Returns
a PyTorch Tensor
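- solve(b, regul=1e-08)
Solves \(Mx = b\) in \(x\), using Tikhonov regularization.
- Parameters
b (nngeometry.object.vector.PVector) – right-hand side vector \(b\)
regul (float) – Tikhonov regularization coefficient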
- vTMv(v)
Computes the quadratic form defined by M in v, namely the product \(v^\top M v\).
- Parameters
v (nngeometry.object.vector.PVector) – vector \(v\)
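When M has low rank, the Woodbury identity turns a d x d solve into a k x k one. The NumPy sketch below assumes M = J^T J for a (k, d) factor J with k much smaller than d, with any 1/n normalization folded into J. It also assumes the regularization is regul times the identity. lowrank_solve is a hypothetical helper.

```python
import numpy as np

def lowrank_solve(J, v, regul=1e-8):
    """Solve (J^T J + regul I) x = v with the Woodbury identity.

    (regul I + J^T J)^{-1} = (I - J^T (regul I_k + J J^T)^{-1} J) / regul,
    so only a k x k system is solved and the d x d matrix is never formed.
    """
    k = J.shape[0]
    small = regul * np.eye(k) + J @ J.T  # k x k instead of d x d
    return (v - J.T @ np.linalg.solve(small, J @ v)) / regul

rng = np.random.default_rng(0)
J = rng.standard_normal((5, 200)) / np.sqrt(5)  # rank-5 factor of a 200 x 200 matrix
v = rng.standard_normal(200)
x = lowrank_solve(J, v, regul=1e-2)
print(np.allclose((J.T @ J + 1e-2 * np.eye(200)) @ x, v))  # True
```

Because the result is divided by regul, very small values of regul amplify rounding errors in the subtraction. The example uses regul=1e-2 for that reason.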
- class nngeometry.object.pspace.PMatQuasiDiag(generator, data=None, examples=None)
Quasi-diagonal approximation as described in Ollivier, Riemannian metrics for neural networks I: feedforward networks, Information and Inference: A Journal of the IMA, 2015
- get_diag()
Computes and returns the diagonal elements of this matrix.
- Returns
a PyTorch Tensor
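- solve(vs, regul=1e-08)
Solves \(Mx = v\) in \(x\), using Tikhonov regularization.
- Parameters
vs (nngeometry.object.vector.PVector) – right-hand side vector \(v\)
regul (float) – Tikhonov regularization coefficient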
- vTMv(vs)
Computes the quadratic form defined by M in v, namely the product \(v^\top M v\).
- Parameters
vs (nngeometry.object.vector.PVector) – vector \(v\)
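In the quasi-diagonal approximation, each neuron keeps the diagonal terms of its bias and incoming weights, plus the cross terms between the bias and each weight. Every other entry is set to zero. Each neuron's block is therefore an arrow-shaped matrix. Eliminating the weights first gives a solve that is linear in the fan-in. The NumPy sketch below assumes this structure and that regul is added to the diagonal. quasidiag_solve_neuron is hypothetical.

```python
import numpy as np

def quasidiag_solve_neuron(m_bb, m_bw, m_ww, v_b, v_w, regul=1e-8):
    """Solve the arrow-shaped system of one neuron (bias b, incoming weights w).

    m_bb: scalar bias-bias term, m_bw: (m,) bias-weight cross terms,
    m_ww: (m,) weight diagonal. All other entries are treated as zero.
    """
    d_b = m_bb + regul
    d_w = m_ww + regul
    # weight equations give x_w = (v_w - m_bw * x_b) / d_w; substitute into the bias equation
    x_b = (v_b - np.sum(m_bw * v_w / d_w)) / (d_b - np.sum(m_bw ** 2 / d_w))
    x_w = (v_w - m_bw * x_b) / d_w
    return x_b, x_w

m_bb, m_bw, m_ww = 5.0, np.array([1.0, 0.5]), np.array([2.0, 3.0])
v_b, v_w = 1.0, np.array([2.0, 1.0])
x_b, x_w = quasidiag_solve_neuron(m_bb, m_bw, m_ww, v_b, v_w, regul=0.0)

# Dense check: the same arrow matrix with the bias first
F = np.array([[5.0, 1.0, 0.5], [1.0, 2.0, 0.0], [0.5, 0.0, 3.0]])
print(np.allclose(F @ np.concatenate([[x_b], x_w]), [v_b, *v_w]))  # True
```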
- nngeometry.object.pspace.bdot(A, B)
Batched dot product.
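The one-line description suggests a row-wise dot product over a batch. The NumPy sketch below assumes that A and B both have shape (batch, n) and that the result has shape (batch,). The shapes are not documented here. torch.einsum accepts the same subscript string for PyTorch tensors.

```python
import numpy as np

def bdot(A, B):
    # row-wise dot products: out[i] = A[i] . B[i]
    return np.einsum("ij,ij->i", A, B)

A = np.array([[1.0, 2.0], [3.0, 4.0]])
B = np.array([[5.0, 6.0], [7.0, 8.0]])
print(bdot(A, B))  # [17. 53.]
```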