jaxwt package

Submodules

jaxwt.conv_fwt module

Convolution based fast wavelet transforms.

jaxwt.conv_fwt.wavedec(data: Array, wavelet: Wavelet, level: Optional[int] = None, mode: str = 'reflect') → List[Array]

Compute the one dimensional analysis wavelet transform of the last dimension.

Parameters
  • data (jnp.array) – Input data array of shape [batch, channels, time]

  • wavelet (Wavelet) – The named tuple containing the wavelet filter arrays.

  • level (int) – The maximum decomposition level. If None, as many levels as possible are used. Defaults to None.

  • mode (str) – The padding used to extend the input signal. Choose reflect or symmetric. Defaults to reflect.

Returns

List containing the wavelet coefficients.

The coefficients are in pywt order: [cA_n, cD_n, cD_n-1, …, cD2, cD1]. A denotes approximation and D detail coefficients.

Return type

list

Examples

>>> import pywt
>>> import jaxwt as jwt
>>> import jax.numpy as jnp
>>> # Generate an input of even length.
>>> data = jnp.array([0., 1., 2., 3, 4, 5, 5, 4, 3, 2, 1, 0])
>>> jwt.wavedec(data, pywt.Wavelet('haar'),
...             mode='reflect', level=2)
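
The following minimal jax.numpy sketch shows what the pywt coefficient order means. It is not jaxwt's implementation. It assumes the Haar wavelet and an even-length last axis at every level, so no padding is needed. The helper names haar_analysis_step and haar_wavedec are hypothetical.

```python
import jax.numpy as jnp


def haar_analysis_step(x):
    """One level of the Haar analysis transform along the last axis.

    Assumes an even-length last axis, so no boundary padding is needed.
    """
    even, odd = x[..., 0::2], x[..., 1::2]
    approx = (even + odd) / jnp.sqrt(2.0)
    detail = (even - odd) / jnp.sqrt(2.0)
    return approx, detail


def haar_wavedec(x, level):
    """Multi-level Haar decomposition returning [cA_n, cD_n, ..., cD1]."""
    details = []
    approx = x
    for _ in range(level):
        approx, detail = haar_analysis_step(approx)
        details.append(detail)  # collected finest first: [cD1, cD2, ...]
    # Reverse the details to obtain pywt order.
    return [approx] + details[::-1]


data = jnp.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0])
coeffs = haar_wavedec(data, level=2)
print([c.shape for c in coeffs])  # [(3,), (3,), (6,)]
```

We believe this reproduces the Haar coefficients of the doctest above, because the two-tap Haar filters never cross an even-length boundary. Longer filters need the padding selected by mode.
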

jaxwt.conv_fwt.waverec(coeffs: List[Array], wavelet: Wavelet) → Array

Reconstruct the original signal in one dimension.

Parameters
  • coeffs (list) – Wavelet coefficients, typically produced by the wavedec function.

  • wavelet (Wavelet) – The named tuple containing the wavelet filters used to evaluate the decomposition.

Returns

Reconstruction of the original data.

Return type

jnp.array

Examples

>>> import pywt
>>> import jaxwt as jwt
>>> import jax.numpy as jnp
>>> # Generate an input of even length.
>>> data = jnp.array([0., 1., 2., 3, 4, 5, 5, 4, 3, 2, 1, 0])
>>> transformed = jwt.wavedec(data, pywt.Wavelet('haar'),
...                           mode='reflect', level=2)
>>> jwt.waverec(transformed, pywt.Wavelet('haar'))
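
The synthesis side can be sketched the same way, under the same Haar and even-length assumptions. Each step inverts the analysis step by recomputing the even and odd samples and interleaving them. The names haar_synthesis_step and haar_waverec are hypothetical. The hand-written coefficients are the two-level Haar decomposition of the example input [0, 1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 0], computed by hand.

```python
import jax.numpy as jnp


def haar_synthesis_step(approx, detail):
    """Invert one Haar analysis level along the last axis."""
    even = (approx + detail) / jnp.sqrt(2.0)
    odd = (approx - detail) / jnp.sqrt(2.0)
    # Interleave even and odd samples: [e0, o0, e1, o1, ...].
    stacked = jnp.stack([even, odd], axis=-1)
    return stacked.reshape(*approx.shape[:-1], -1)


def haar_waverec(coeffs):
    """Reconstruct a signal from a pywt-ordered list [cA_n, cD_n, ..., cD1]."""
    approx = coeffs[0]
    for detail in coeffs[1:]:
        approx = haar_synthesis_step(approx, detail)
    return approx


# Two-level Haar coefficients of [0, 1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 0],
# written out by hand so the block is self-contained.
s = 1 / jnp.sqrt(2.0)
coeffs = [
    jnp.array([3.0, 9.0, 3.0]),
    jnp.array([-2.0, 0.0, 2.0]),
    jnp.array([-s, -s, -s, s, s, s]),
]
print(haar_waverec(coeffs))
```
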

jaxwt.conv_fwt_2d module

Two dimensional convolution based fast wavelet transforms.

jaxwt.conv_fwt_2d.construct_2d_filt(lo: Array, hi: Array) → Array

Construct 2d filters from 1d inputs using outer products.

Parameters
  • lo (jnp.array) – 1d lowpass input filter of size [1, length].

  • hi (jnp.array) – 1d highpass input filter of size [1, length].

Returns

The 2d filters, stacked into an array of shape [4, 1, length, length].

Return type

jnp.array
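
Here is a minimal sketch of the outer-product construction, using pywt's Haar decomposition filters as input. The stacking order [ll, lh, hl, hh] and the helper name construct_2d_filt_sketch are assumptions for illustration. The library may order the filters differently.

```python
import jax.numpy as jnp


def construct_2d_filt_sketch(lo, hi):
    """Build separable 2d filters from 1d filters via outer products.

    The stacking order [ll, lh, hl, hh] is an assumption made for
    illustration only.
    """
    lo, hi = lo.ravel(), hi.ravel()  # accept [1, length] or [length]
    ll = jnp.outer(lo, lo)
    lh = jnp.outer(hi, lo)
    hl = jnp.outer(lo, hi)
    hh = jnp.outer(hi, hi)
    filt = jnp.stack([ll, lh, hl, hh], axis=0)  # [4, length, length]
    return filt[:, None, :, :]  # add a channel axis -> [4, 1, length, length]


s = 1 / jnp.sqrt(2.0)
lo = jnp.array([[s, s]])   # Haar lowpass, shape [1, 2]
hi = jnp.array([[-s, s]])  # Haar highpass, shape [1, 2]
print(construct_2d_filt_sketch(lo, hi).shape)  # (4, 1, 2, 2)
```
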

jaxwt.conv_fwt_2d.wavedec2(data: Array, wavelet: Wavelet, level: Optional[int] = None, mode: str = 'reflect') → List[Union[Array, Tuple[Array, Array, Array]]]

Compute the two dimensional wavelet analysis transform on the last two dimensions of the input data array.

Parameters
  • data (jnp.array) – JAX array containing the data to be transformed. Assumed shape: [batch size, height, width].

  • wavelet (Wavelet) – A named tuple containing the filters for the transformation.

  • level (int) – The maximum decomposition level. If None, as many levels as possible are used. Defaults to None.

  • mode (str) – The desired padding mode. Choose reflect or symmetric. Defaults to reflect.

Returns

The wavelet coefficients in a nested list.

The coefficients are in pywt order. That is: [cAn, (cHn, cVn, cDn), … (cH1, cV1, cD1)]. A denotes approximation, H horizontal, V vertical and D diagonal coefficients.

Return type

list

Examples

>>> import pywt, scipy.misc
>>> import jaxwt as jwt
>>> import jax.numpy as jnp
>>> face = jnp.transpose(scipy.misc.face(), [2, 0, 1]).astype(jnp.float64)
>>> jwt.wavedec2(face, pywt.Wavelet("haar"), level=2, mode="reflect")
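
A separable 2d Haar decomposition can be sketched by applying a 1d Haar step along the width and then along the height. This is a simplified illustration, not jaxwt's convolution-based code. It assumes even height and width at every level, and the helper names are hypothetical. The assignment of cH (highpass along the height) and cV (highpass along the width) follows our reading of pywt's dwt2 convention.

```python
import jax.numpy as jnp


def haar_1d(x, axis):
    """Haar analysis along `axis`; assumes that axis has even length."""
    x = jnp.moveaxis(x, axis, -1)
    even, odd = x[..., 0::2], x[..., 1::2]
    approx = (even + odd) / jnp.sqrt(2.0)
    detail = (even - odd) / jnp.sqrt(2.0)
    return jnp.moveaxis(approx, -1, axis), jnp.moveaxis(detail, -1, axis)


def haar_dwt2(x):
    """One separable 2d Haar level on the last two axes -> (cA, (cH, cV, cD))."""
    lo_w, hi_w = haar_1d(x, axis=-1)   # split along the width
    c_a, c_h = haar_1d(lo_w, axis=-2)  # lowpass width, then split along height
    c_v, c_d = haar_1d(hi_w, axis=-2)  # highpass width, then split along height
    return c_a, (c_h, c_v, c_d)


def haar_wavedec2(x, level):
    """Multi-level 2d Haar decomposition in pywt order."""
    details = []
    for _ in range(level):
        x, detail_tuple = haar_dwt2(x)
        details.append(detail_tuple)
    return [x] + details[::-1]


image = jnp.arange(2 * 8 * 8, dtype=jnp.float32).reshape(2, 8, 8)  # [batch, height, width]
coeffs = haar_wavedec2(image, level=2)
print(coeffs[0].shape, [c.shape for c in coeffs[1]])  # (2, 2, 2) [(2, 2, 2), (2, 2, 2), (2, 2, 2)]
```

For this linear ramp image, the finest diagonal band is zero, because a ramp has no mixed second differences.
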

jaxwt.conv_fwt_2d.waverec2(coeffs: List[Union[Array, Tuple[Array, Array, Array]]], wavelet: Wavelet) → Array

Compute a two-dimensional synthesis wavelet transform.

Use it to reconstruct the original input image from the wavelet coefficients.

Parameters
  • coeffs (list) – The input coefficients, typically the output of wavedec2.

  • wavelet (Wavelet) – The named tuple containing the filters used to compute the analysis transform.

Returns

Reconstruction of the original input data array of shape [batch, height, width].

Return type

jnp.array

Raises

ValueError – If coeffs does not have the structure returned by wavedec2.

Example

>>> import pywt, scipy.misc
>>> import jaxwt as jwt
>>> import jax.numpy as jnp
>>> face = jnp.transpose(scipy.misc.face(), [2, 0, 1]).astype(jnp.float64)
>>> transformed = jwt.wavedec2(face, pywt.Wavelet("haar"), level=2, mode="reflect")
>>> jwt.waverec2(transformed, pywt.Wavelet("haar"))
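
The inverse can be sketched by undoing the 1d Haar splits in reverse order: first along the height, then along the width. This sketch assumes the Haar wavelet and even-sized inputs. The helpers haar_1d, ihaar_1d and haar_waverec2 are hypothetical. The one-level round trip at the end only checks that the sketch inverts itself, not that it matches jaxwt numerically.

```python
import jax.numpy as jnp

SQRT2 = jnp.sqrt(2.0)


def haar_1d(x, axis):
    """Haar analysis along `axis` (even length assumed)."""
    x = jnp.moveaxis(x, axis, -1)
    a = (x[..., 0::2] + x[..., 1::2]) / SQRT2
    d = (x[..., 0::2] - x[..., 1::2]) / SQRT2
    return jnp.moveaxis(a, -1, axis), jnp.moveaxis(d, -1, axis)


def ihaar_1d(a, d, axis):
    """Invert haar_1d along `axis` by interleaving even and odd samples."""
    a, d = jnp.moveaxis(a, axis, -1), jnp.moveaxis(d, axis, -1)
    even, odd = (a + d) / SQRT2, (a - d) / SQRT2
    x = jnp.stack([even, odd], axis=-1).reshape(*a.shape[:-1], -1)
    return jnp.moveaxis(x, -1, axis)


def haar_waverec2(coeffs):
    """Rebuild [batch, height, width] data from [cA, (cH, cV, cD), ...]."""
    x = coeffs[0]
    for c_h, c_v, c_d in coeffs[1:]:
        lo_w = ihaar_1d(x, c_h, axis=-2)    # undo the height split of the lowpass-width band
        hi_w = ihaar_1d(c_v, c_d, axis=-2)  # undo the height split of the highpass-width band
        x = ihaar_1d(lo_w, hi_w, axis=-1)   # undo the width split
    return x


# One-level round trip on a small [batch, height, width] array.
image = jnp.arange(2 * 4 * 4, dtype=jnp.float32).reshape(2, 4, 4)
lo_w, hi_w = haar_1d(image, axis=-1)
c_a, c_h = haar_1d(lo_w, axis=-2)
c_v, c_d = haar_1d(hi_w, axis=-2)
restored = haar_waverec2([c_a, (c_h, c_v, c_d)])
print(bool(jnp.allclose(restored, image, atol=1e-5)))  # True
```
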

jaxwt.continuous_transform module

JAX-compatible continuous wavelet transform (CWT) code.

jaxwt.continuous_transform.cwt(data: Array, scales: Union[ndarray, Array], wavelet: Union[ContinuousWavelet, str], sampling_period: float = 1.0) → Tuple[Array, Array]

Compute the one-dimensional continuous wavelet transform.

This function is a JAX port of pywt.cwt as found at: https://github.com/PyWavelets/pywt/blob/master/pywt/_cwt.py

Parameters
  • data (jnp.ndarray) – The input tensor of shape [batch_size, time].

  • scales (np.ndarray or jnp.array) – The wavelet scales to use. The physical frequency f corresponding to a scale can be computed as f = pywt.scale2frequency(wavelet, scale)/sampling_period. Here, f is in hertz when sampling_period is given in seconds.

  • wavelet (ContinuousWavelet or str) – The continuous wavelet to work with.

  • sampling_period (float) – Sampling period for the frequencies output (optional). The values computed for coefs are independent of the choice of sampling_period (i.e. scales is not scaled by the sampling period).

Raises

ValueError – If a scale is too small for the input signal.

Returns

A tuple with the transformation matrix and the frequencies, in this order.

Return type

Tuple[jnp.ndarray, jnp.ndarray]

Example

>>> import numpy as np
>>> import jax.numpy as jnp
>>> import scipy.signal as signal
>>> import jaxwt as jwt
>>> t = np.linspace(-2, 2, 800, endpoint=False)
>>> sig = signal.chirp(t, f0=1, f1=12, t1=2, method="linear")
>>> widths = np.arange(1, 31)
>>> cwtmatr, freqs = jwt.cwt(
...     jnp.array(sig), widths, "mexh", sampling_period=(4 / 800) * np.pi
... )
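
The following jax.numpy sketch makes the scale-to-frequency relation concrete. It computes a CWT with a Mexican hat wavelet by direct convolution. It is a simplified stand-in, not the algorithm used by jaxwt.continuous_transform.cwt or pywt.cwt, which integrate and resample the mother wavelet. The names simple_cwt and mexican_hat are hypothetical, and the wavelet support is truncated to [-5s, 5s]. center_frequency=0.25 is an illustrative value; in real use, replace it with pywt.central_frequency('mexh'). The frequencies follow f = center_frequency / (scale * sampling_period), which is what the pywt.scale2frequency formula reduces to.

```python
import jax.numpy as jnp


def mexican_hat(t):
    """Mexican hat (Ricker) wavelet with the normalisation pywt uses for 'mexh'."""
    return 2.0 / (jnp.sqrt(3.0) * jnp.pi**0.25) * (1.0 - t**2) * jnp.exp(-(t**2) / 2.0)


def simple_cwt(signal, scales, center_frequency, sampling_period=1.0):
    """Direct-convolution CWT sketch for a real, symmetric mother wavelet.

    For each scale s the wavelet is sampled as psi(t / s) / sqrt(s) on
    integer t in [-5 s, 5 s] and convolved with the signal.
    """
    rows = []
    for s in scales:
        s = float(s)
        t = jnp.arange(-5 * s, 5 * s + 1, dtype=jnp.float32)
        kernel = mexican_hat(t / s) / jnp.sqrt(s)
        rows.append(jnp.convolve(signal, kernel, mode="same"))
    coefs = jnp.stack(rows)  # [num_scales, time]
    freqs = center_frequency / (jnp.asarray(scales) * sampling_period)
    return coefs, freqs


t = jnp.linspace(0.0, 1.0, 200, endpoint=False)
signal = jnp.sin(2 * jnp.pi * 8 * t)
coefs, freqs = simple_cwt(signal, [1, 2, 3, 4, 5], center_frequency=0.25, sampling_period=1 / 200)
print(coefs.shape, freqs)
```
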

jaxwt.packets module

Compute wavelet packets using jwt.

class jaxwt.packets.WaveletPacket(data: Array, wavelet: Wavelet, mode: str = 'reflect', max_level: Optional[int] = None)

Bases: UserDict

A wavelet packet tree.

get_level(level: int) → List[str]

Return the node names of a given level in Gray-code order.

Parameters

level (int) – The required depth of the tree.

Returns

A list with the node names.

Return type

list
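
The Gray-code ordering can be sketched in plain Python; pywt's WaveletPacket.get_level calls it the 'freq' order. The example also builds a toy Haar packet tree in a dict. It assumes nodes are keyed by paths of 'a' (approximation) and 'd' (detail), with the root stored under the empty string. build_packet_tree and graycode_order are hypothetical helpers, not the WaveletPacket internals.

```python
import jax.numpy as jnp


def haar_step(x):
    """One Haar analysis level along the last axis (even length assumed)."""
    even, odd = x[..., 0::2], x[..., 1::2]
    return (even + odd) / jnp.sqrt(2.0), (even - odd) / jnp.sqrt(2.0)


def build_packet_tree(data, max_level):
    """Split every node, storing nodes under path names such as 'a', 'ad', 'dda'."""
    tree = {"": data}
    for level in range(max_level):
        # Materialise the current level's paths before adding children.
        for path in [p for p in tree if len(p) == level]:
            approx, detail = haar_step(tree[path])
            tree[path + "a"] = approx
            tree[path + "d"] = detail
    return tree


def graycode_order(level):
    """Node names of one level in Gray-code (frequency) order."""
    order = ["a", "d"]
    for _ in range(level - 1):
        order = ["a" + p for p in order] + ["d" + p for p in order[::-1]]
    return order


tree = build_packet_tree(jnp.arange(8.0), max_level=3)
print(graycode_order(2))  # ['aa', 'ad', 'dd', 'da']
```
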

class jaxwt.packets.WaveletPacket2D(data: Array, wavelet: Union[str, Wavelet], mode: str = 'reflect', max_level: Optional[int] = None)

Bases: UserDict

A two-dimensional wavelet packet tree.

get_level(level: int) → List[str]

Return the graycodes for a given level.

Parameters

level (int) – The required depth of the tree.

Returns

A list with the node names.

Return type

list
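
In the 2d tree, each node has four children. Assuming they are labeled 'a', 'h', 'v' and 'd' (approximation, horizontal, vertical and diagonal), all node names of a level can be enumerated with itertools.product. natural_order_2d is a hypothetical helper, and this natural ordering is not necessarily the order jaxwt's get_level returns.

```python
from itertools import product


def natural_order_2d(level):
    """All node names of one 2d packet level, built from 'a', 'h', 'v', 'd'."""
    return ["".join(p) for p in product("ahvd", repeat=level)]


print(natural_order_2d(1))       # ['a', 'h', 'v', 'd']
print(len(natural_order_2d(2)))  # 16
```
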

jaxwt.utils module

Various utility functions.

jaxwt.utils.flatten_2d_coeff_lst(coeff_list_2d: List[Union[Array, Tuple[Array, Array, Array]]], flatten_arrays: bool = True) → List[Array]

Flatten a list of array tuples into a single list.

Parameters
  • coeff_list_2d (list) – A pywt-style coefficient list.

  • flatten_arrays (bool) – If True, 2d arrays are flattened. Defaults to True.

Returns

A single 1-d list with all original elements.

Return type

list
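
Here is a minimal sketch of the flattening, assuming the pywt-style structure [cA, (cH, cV, cD), ...]. We have not confirmed that jaxwt ravels arrays exactly this way, for example with respect to leading batch axes. flatten_2d_coeff_lst_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def flatten_2d_coeff_lst_sketch(coeff_list_2d, flatten_arrays=True):
    """Turn [cA, (cH, cV, cD), ...] into [cA, cH, cV, cD, ...]."""
    flat = []
    for entry in coeff_list_2d:
        # Tuples hold the three detail arrays; a bare array is the approximation.
        arrays = list(entry) if isinstance(entry, tuple) else [entry]
        for array in arrays:
            flat.append(array.ravel() if flatten_arrays else array)
    return flat


coeffs = [jnp.ones((2, 2)), (jnp.zeros((2, 2)),) * 3, (jnp.zeros((4, 4)),) * 3]
print([c.shape for c in flatten_2d_coeff_lst_sketch(coeffs)])
# [(4,), (4,), (4,), (4,), (16,), (16,), (16,)]
```
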

jaxwt.version module

Version information for jaxwt.

Run with python -m jaxwt.version

jaxwt.version.get_git_hash() → str

Get the jaxwt git hash.

jaxwt.version.get_version(with_git_hash: bool = False) → str

Get the jaxwt version string, optionally including the git hash.

Module contents

Differentiable and GPU-enabled fast wavelet transforms in JAX.