This package offers `einsum`-based implementations of convolutions and related operations in PyTorch.
Its name is inspired by a GitHub repository that served as the starting point for our work.
Install from PyPI via `pip`:

```bash
pip install einconv
```
For more tutorials, check out the documentation.
In general, `einconv`'s goals are:
- Full hyper-parameter support (stride, padding, dilation, groups, etc.)
- Support for any dimension (e.g. 5d convolution)
- Optimizations via symbolic simplification
`einconv` provides `einsum`-based implementations of the following PyTorch modules:
| `torch` module | `einconv` module |
|---|---|
| `nn.Conv{1,2,3}d` | `modules.ConvNd` |
| `nn.Unfold` | `modules.UnfoldNd` |
They work in exactly the same way as their PyTorch equivalents.
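As a quick illustration, here is a minimal sketch of a 4d convolution with `modules.ConvNd`. The exact constructor signature is an assumption (the convolution dimension passed first, followed by `nn.Conv2d`-style arguments); consult the documentation for details.

```python
import torch

from einconv.modules import ConvNd

x = torch.rand(2, 3, 8, 8, 8, 8)  # batch of 4d inputs: (batch, channels, *spatial)

# Assumption: the convolution dimension N comes first, the remaining
# arguments mirror nn.Conv2d's signature.
layer = ConvNd(4, in_channels=3, out_channels=4, kernel_size=3, padding=1)

output = layer(x)  # expected shape: (2, 4, 8, 8, 8, 8)
```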
`einconv` provides `einsum`-based implementations of the following PyTorch functionals:
| `torch` functional | `einconv` functional |
|---|---|
| `nn.functional.conv{1,2,3}d` | `functionals.convNd` |
| `nn.functional.unfold` | `functionals.unfoldNd` |
They work in exactly the same way as their PyTorch equivalents.
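As a sketch of this drop-in behavior, the following compares `functionals.convNd` against `nn.functional.conv2d`. It assumes `convNd` mirrors `conv2d`'s signature and infers the dimension from its input; treat these details as assumptions.

```python
import torch
from torch.nn import functional as F

from einconv.functionals import convNd

x = torch.rand(2, 3, 8, 8)       # input of a 2d convolution
weight = torch.rand(4, 3, 3, 3)  # kernel

ground_truth = F.conv2d(x, weight, padding=1)
# Assumption: convNd mirrors conv2d's signature, inferring N from the input.
result = convNd(x, weight, padding=1)

assert torch.allclose(ground_truth, result, atol=1e-6)
```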
`einconv` can generate `einsum` expressions (equation, operands, and output shape) for the following operations:
- Forward pass of N-dimensional convolution
- Backward pass (input and weight VJPs) of N-dimensional convolution
- Input unfolding (`im2col`/`unfold`) for inputs of N-dimensional convolution
- Input-based Kronecker factors of Fisher approximations for convolutions (KFC and KFAC-reduce)
These can then be evaluated with `einsum`. For instance, the `einsum` expression for the forward pass of an N-dimensional convolution is:
```python
from torch import einsum

from einconv.expressions import convNd_forward

equation, operands, shape = convNd_forward.einsum_expression(...)
result = einsum(equation, *operands).reshape(shape)
```
All expressions follow this pattern.
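For concreteness, a hypothetical invocation could look as follows. The arguments passed to `einsum_expression` are assumptions for illustration; the real signature may differ.

```python
import torch
from torch import einsum

from einconv.expressions import convNd_forward

x = torch.rand(2, 3, 8, 8)       # input of a 2d convolution
weight = torch.rand(4, 3, 3, 3)  # kernel

# Assumption: the expression is generated from the input and kernel,
# plus convolution hyper-parameters.
equation, operands, shape = convNd_forward.einsum_expression(x, weight, padding=1)
result = einsum(equation, *operands).reshape(shape)  # expected shape: (2, 4, 8, 8)
```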
Some operations (e.g. dense convolutions) can be optimized via symbolic simplifications. This is turned on by default as it generally improves performance. You can also generate a non-optimized expression and simplify it:
```python
from torch import einsum

from einconv import simplify
from einconv.expressions import convNd_forward

equation, operands, shape = convNd_forward.einsum_expression(..., simplify=False)
equation, operands = simplify(equation, operands)
result = einsum(equation, *operands).reshape(shape)
```
Sometimes it might be better to inspect the non-simplified expression to see how indices relate to operands.
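One way to do this, sketched below under the same signature assumptions as above, is to print the equation and operand shapes before and after simplification:

```python
import torch

from einconv import simplify
from einconv.expressions import convNd_forward

x = torch.rand(2, 3, 8, 8)
weight = torch.rand(4, 3, 3, 3)

# Assumption: same hypothetical signature as in the example above.
equation, operands, shape = convNd_forward.einsum_expression(x, weight, simplify=False)
print(equation)                       # indices of the non-simplified contraction
print([op.shape for op in operands])  # one entry per operand

simplified_equation, simplified_operands = simplify(equation, operands)
print(simplified_equation)            # the simplified equation may involve fewer operands
```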
If you find the `einconv` package useful for your research, consider citing the accompanying article:
```bibtex
@article{dangel2023convolutions,
  title  = {Convolutions Through the Lens of Tensor Networks},
  author = {Dangel, Felix},
  year   = 2023,
}
```
Some current limitations:

- None of the underlying operations (computation of index pattern tensors, generation of `einsum` equations and shapes, simplification) is cached. This consumes additional time, although usually much less than evaluating an expression via `einsum`.
- The code that performs expression simplifications is coupled with PyTorch. I am planning to address this in the future by switching to a symbolic implementation, which will also allow efficient caching.