Jax Wavelet Toolbox (jaxwt)
Differentiable and GPU-enabled fast wavelet transforms in JAX.
Features
- 1d analysis and synthesis transforms are implemented in src/jaxwt/conv_fwt.py. Try wavedec and waverec.
- 2d analysis and synthesis transforms are part of the src/jaxwt/conv_fwt_2d.py module. The two functions are called wavedec2 and waverec2.
- 3d transforms are provided by the wavedec3 and waverec3 functions.
- The cwt function supports 1d continuous wavelet transforms.
- The WaveletPacket object supports 1d wavelet packet transforms.
- WaveletPacket2d implements two-dimensional wavelet packet transforms.
- swt computes a one-dimensional stationary wavelet transform; iswt inverts it.
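To illustrate what a multi-level analysis/synthesis pair computes, here is a minimal sketch of a Haar fast wavelet transform written directly in jax.numpy. It is not jaxwt's convolution-based implementation: the helpers haar_wavedec and haar_waverec are hypothetical, and the sketch assumes a 1d input whose length is divisible by 2**level. Coefficients are returned in the PyWavelets order [cA_n, cD_n, ..., cD_1].

```python
import jax.numpy as jnp


def haar_wavedec(x, level):
    """Hypothetical multi-level Haar analysis (not jaxwt's implementation).

    Assumes a 1d array whose length is divisible by 2**level.
    Returns [cA_level, cD_level, ..., cD_1].
    """
    details = []
    approx = x
    for _ in range(level):
        even, odd = approx[0::2], approx[1::2]
        # Orthonormal Haar filters: scaled pairwise differences and sums.
        details.append((even - odd) / jnp.sqrt(2.0))
        approx = (even + odd) / jnp.sqrt(2.0)
    # Coarsest details first, finest details last.
    return [approx] + details[::-1]


def haar_waverec(coeffs):
    """Hypothetical inverse of haar_wavedec (synthesis)."""
    approx = coeffs[0]
    for detail in coeffs[1:]:
        even = (approx + detail) / jnp.sqrt(2.0)
        odd = (approx - detail) / jnp.sqrt(2.0)
        # Interleave the recovered even and odd samples.
        approx = jnp.stack([even, odd], axis=-1).reshape(-1)
    return approx


data = jnp.array([0., 1, 2, 3, 4, 5, 6, 7, 7, 6, 5, 4, 3, 2, 1, 0])
coeffs = haar_wavedec(data, level=2)
print([c.shape for c in coeffs])  # [(4,), (4,), (8,)]
print(jnp.max(jnp.abs(haar_waverec(coeffs) - data)))  # reconstruction error
```

Each level halves the approximation, which is why the coefficient lists shrink geometrically; library transforms additionally handle padding modes and longer filters.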
This toolbox extends PyWavelets.
jaxwt additionally provides GPU and gradient support via a Jax backend.
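Gradient support comes from running on JAX: a function built from jax.numpy operations can be passed to jax.grad and jax.jit. The sketch below shows the idea on a hypothetical one-level Haar helper (haar_detail) rather than on jaxwt itself; it differentiates the energy of the detail coefficients, a simple smoothness penalty, with respect to the input signal.

```python
import jax
import jax.numpy as jnp


def haar_detail(x):
    """Hypothetical helper: one-level Haar detail coefficients (even-length input)."""
    return (x[0::2] - x[1::2]) / jnp.sqrt(2.0)


def detail_energy(x):
    # Sum of squared detail coefficients; zero when every pair (x[2k], x[2k+1]) is constant.
    return jnp.sum(haar_detail(x) ** 2)


# jax.grad differentiates through the transform; jax.jit compiles the gradient function.
grad_fn = jax.jit(jax.grad(detail_energy))

signal = jnp.array([0., 1., 2., 3., 4., 5., 6., 7.])
# Analytically: d/dx[2k] = x[2k] - x[2k+1] and d/dx[2k+1] = x[2k+1] - x[2k].
print(grad_fn(signal))
```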
Installation
To install Jax, head over to https://github.com/google/jax#installation and follow the procedure described there.
Afterward, type pip install jaxwt to install the Jax-Wavelet-Toolbox. You can uninstall it later by typing pip uninstall jaxwt.
Documentation
The documentation is available at https://jax-wavelet-toolbox.readthedocs.io/en/latest/jaxwt.html
Transform Examples
One-dimensional fast wavelet transform:
import pywt
import numpy as np
import jax.numpy as jnp
import jaxwt as jwt
# Generate an input of even length.
data = jnp.array([0., 1, 2, 3, 4, 5, 6, 7, 7, 6, 5, 4, 3, 2, 1, 0])
wavelet = pywt.Wavelet('haar')
# Compare the forward fwt coefficients of pywt and jaxwt.
print(pywt.wavedec(np.array(data), wavelet, mode='zero', level=2))
print(jwt.wavedec(data, wavelet, mode='zero', level=2))
# Invert the fwt; this should recover the input.
print(jwt.waverec(jwt.wavedec(data, wavelet, mode='zero', level=2),
                  wavelet))
Two-dimensional fast wavelet transform:
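import pywt
import scipy.datasets
import jax.numpy as jnp
import jaxwt as jwt
# Move the color channels to the front, [3, 768, 1024], so the 2d
# transform runs over the two spatial axes.
face = jnp.transpose(
    scipy.datasets.face(), [2, 0, 1]).astype(jnp.float64)
transformed = jwt.wavedec2(face, pywt.Wavelet("haar"),
                           level=2, mode="reflect")
reconstruction = jwt.waverec2(transformed, pywt.Wavelet("haar"))
# The maximum reconstruction error should be close to zero.
print(jnp.max(jnp.abs(face - reconstruction)))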
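For intuition about what one level of wavedec2 produces, the following sketch implements a single-level 2d Haar transform in jax.numpy over the last two axes, so a channel-first array such as the face image above is transformed per channel. The names haar_dwt2, haar_idwt2, _split and _merge are hypothetical, the sketch assumes even height and width, and it is not jaxwt's implementation. The detail ordering (detail along rows, detail along columns, detail along both) is assumed to mirror pywt.dwt2's (cA, (cH, cV, cD)) tuple; this has not been checked against the library.

```python
import jax.numpy as jnp

SQRT2 = jnp.sqrt(2.0)


def _split(y, axis):
    """One Haar analysis step along `axis`: returns (approximation, detail)."""
    y = jnp.moveaxis(y, axis, -1)
    even, odd = y[..., 0::2], y[..., 1::2]
    lo, hi = (even + odd) / SQRT2, (even - odd) / SQRT2
    return jnp.moveaxis(lo, -1, axis), jnp.moveaxis(hi, -1, axis)


def _merge(lo, hi, axis):
    """Inverse of _split: rebuild and interleave samples along `axis`."""
    lo, hi = jnp.moveaxis(lo, axis, -1), jnp.moveaxis(hi, axis, -1)
    even, odd = (lo + hi) / SQRT2, (lo - hi) / SQRT2
    y = jnp.stack([even, odd], axis=-1).reshape(even.shape[:-1] + (-1,))
    return jnp.moveaxis(y, -1, axis)


def haar_dwt2(x):
    """Hypothetical single-level 2d Haar analysis over axes (-2, -1)."""
    lo, hi = _split(x, -2)   # filter along rows (axis -2)
    aa, ad = _split(lo, -1)  # then along columns (axis -1)
    da, dd = _split(hi, -1)
    return aa, (da, ad, dd)


def haar_idwt2(coeffs):
    """Hypothetical inverse of haar_dwt2."""
    aa, (da, ad, dd) = coeffs
    lo = _merge(aa, ad, -1)
    hi = _merge(da, dd, -1)
    return _merge(lo, hi, -2)


# A small channel-first toy "image" with 3 channels of size 4 x 6.
image = jnp.arange(3 * 4 * 6, dtype=jnp.float32).reshape(3, 4, 6)
cA, (cH, cV, cD) = haar_dwt2(image)
print(cA.shape)  # (3, 2, 3)
print(jnp.max(jnp.abs(haar_idwt2((cA, (cH, cV, cD))) - image)))
```

Each subband has half the height and width of its input; repeating the analysis on cA gives the next level of a multi-level decomposition.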
Testing
Unit tests are handled by nox. Clone the repository and run the test session with the following commands:
$ pip install nox
$ git clone https://github.com/v0lta/Jax-Wavelet-Toolbox
$ cd Jax-Wavelet-Toolbox
$ nox -s test
Goals
In the spirit of Jax, the aim is to be 100% pywt compatible. Whenever possible, interfaces should be the same and results identical.
64-bit floating-point numbers
If you need 64-bit floating-point support, set the Jax config flag at the start of your program:
import jax
jax.config.update("jax_enable_x64", True)
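A quick way to confirm that 64-bit support is active is to inspect array dtypes. This sketch assumes a JAX release where jax.config.update is available; the flag should be set before any JAX arrays are created. Without it, JAX falls back to float32 even when float64 data is passed in.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Enable 64-bit types before creating any JAX arrays.
jax.config.update("jax_enable_x64", True)

# NumPy produces float64 data; with x64 enabled, JAX keeps that precision.
x = jnp.asarray(np.linspace(0.0, 1.0, 8))
print(x.dtype)  # float64
# Default floating-point arrays are now 64-bit as well.
print(jnp.ones(3).dtype)  # float64
```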
Citation
If you use this work in a scientific context, please cite:
@phdthesis{handle:20.500.11811/9245,
  urn = {https://nbn-resolving.org/urn:nbn:de:hbz:5-63361},
  author = {{Moritz Wolter}},
  title = {Frequency Domain Methods in Recurrent Neural Networks for Sequential Data Processing},
  school = {Rheinische Friedrich-Wilhelms-Universität Bonn},
  year = 2021,
  month = jul,
  url = {https://hdl.handle.net/20.500.11811/9245}
}
Use any of the jaxwt package links below to directly jump into the documentation.
jaxwt module overview: