jaxwt package

jaxwt.conv_fwt module

Convolution based fast wavelet transforms.

jaxwt.conv_fwt.wavedec(data: Array, wavelet: Wavelet, level: int | None = None, mode: str = 'reflect', precision: str = 'highest') → List[Array]

Compute the analysis wavelet transform of the last dimension.

Parameters:
  • data (jnp.ndarray) – Input data array of shape [batch, time].

  • wavelet (pywt.Wavelet) – The named tuple containing the wavelet filter arrays.

  • level (int) – Max scale level to be used. If None, as many levels as possible are used. Defaults to None.

  • mode (str) – The padding used to extend the input signal. Choose reflect, symmetric or zero. Defaults to reflect.

  • precision (str) – The desired precision, choose “fastest”, “high” or “highest”. Defaults to “highest”.

Returns:

List containing the wavelet coefficients of shape [batch_size, coefficients].

The coefficients are in pywt order: [cA_n, cD_n, cD_n-1, …, cD2, cD1]. A denotes approximation and D detail coefficients.

Return type:

list

Raises:

ValueError – If the dimensionality of the input data array is unsupported.

Examples

>>> import pywt
>>> import jaxwt as jwt
>>> import jax.numpy as jnp
>>> # generate an input of even length.
>>> data = jnp.array([0., 1., 2., 3, 4, 5, 5, 4, 3, 2, 1, 0])
>>> jwt.wavedec(data, wavelet=pywt.Wavelet('haar'), level=2)
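
For the Haar wavelet and an even-length input, no boundary padding is needed, so each analysis level reduces to scaled pairwise sums and differences. The NumPy sketch below reproduces the pywt coefficient order for the example data above. `haar_step` and `haar_wavedec` are hypothetical helpers, not part of jaxwt, and we assume jaxwt's Haar output follows pywt's sign convention, detail = (even - odd) / √2.

```python
import numpy as np

SQRT2 = np.sqrt(2.0)


def haar_step(x):
    """One Haar analysis step along the last axis (even length assumed)."""
    even, odd = x[..., 0::2], x[..., 1::2]
    approx = (even + odd) / SQRT2
    detail = (even - odd) / SQRT2  # assumed pywt Haar sign convention
    return approx, detail


def haar_wavedec(x, level):
    """Multi-level Haar decomposition in pywt order [cA_n, cD_n, ..., cD1]."""
    approx = np.asarray(x, dtype=np.float64)
    details = []
    for _ in range(level):
        approx, detail = haar_step(approx)
        details.append(detail)
    # The finest details are computed first, so reverse them to put cD_n first.
    return [approx] + details[::-1]


data = np.array([0., 1., 2., 3, 4, 5, 5, 4, 3, 2, 1, 0])
coeffs = haar_wavedec(data, level=2)
for name, c in zip(["cA2", "cD2", "cD1"], coeffs):
    print(name, np.round(c, 4))
```

For this input the sketch yields cA_2 = [3, 9, 3], cD_2 = [-2, 0, 2] and cD_1 = [-1, -1, -1, 1, 1, 1] / √2. Because the Haar transform is orthogonal, the squared coefficients sum to the signal energy, 110.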

jaxwt.conv_fwt.waverec(coeffs: List[Array], wavelet: Wavelet, precision: str = 'highest') → Array

Reconstruct the original signal in one dimension.

Parameters:
  • coeffs (list) – Wavelet coefficients, typically produced by the wavedec function. List entries of shape [batch_size, coefficients] work.

  • wavelet (pywt.Wavelet) – The named tuple containing the wavelet filters used to evaluate the decomposition.

  • precision (str) – The desired precision, choose “fastest”, “high” or “highest”. Defaults to “highest”.

Returns:

Reconstruction of the original data.

Return type:

jnp.array

Examples

>>> import pywt
>>> import jaxwt as jwt
>>> import jax.numpy as jnp
>>> # generate an input of even length.
>>> data = jnp.array([0., 1., 2., 3, 4, 5, 5, 4, 3, 2, 1, 0])
>>> transformed = jwt.wavedec(data, pywt.Wavelet('haar'))
>>> jwt.waverec(transformed, pywt.Wavelet('haar'))
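
Synthesis undoes the analysis one level at a time, starting from the coarsest approximation. Continuing with Haar filters, the even and odd samples follow from (cA + cD) / √2 and (cA - cD) / √2, and interleaving them restores the previous approximation. `haar_wavedec` and `haar_waverec` are hypothetical NumPy helpers for illustration; they are not jaxwt's convolution-based implementation.

```python
import numpy as np

SQRT2 = np.sqrt(2.0)


def haar_wavedec(x, level):
    """Haar analysis in pywt order [cA_n, cD_n, ..., cD1] (even lengths assumed)."""
    approx, details = np.asarray(x, dtype=np.float64), []
    for _ in range(level):
        even, odd = approx[..., 0::2], approx[..., 1::2]
        approx, detail = (even + odd) / SQRT2, (even - odd) / SQRT2
        details.append(detail)
    return [approx] + details[::-1]


def haar_waverec(coeffs):
    """Invert haar_wavedec level by level, coarsest level first."""
    approx = coeffs[0]
    for detail in coeffs[1:]:
        even = (approx + detail) / SQRT2
        odd = (approx - detail) / SQRT2
        # Interleave even and odd samples along the last axis.
        approx = np.stack([even, odd], axis=-1).reshape(*even.shape[:-1], -1)
    return approx


data = np.array([0., 1., 2., 3, 4, 5, 5, 4, 3, 2, 1, 0])
reconstruction = haar_waverec(haar_wavedec(data, level=2))
print(np.allclose(reconstruction, data))  # True
```

The helpers also accept a leading batch axis, matching the [batch_size, coefficients] layout described for waverec.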

jaxwt.conv_fwt_2d module

Two dimensional convolution based fast wavelet transforms.

jaxwt.conv_fwt_2d.wavedec2(data: Array, wavelet: Wavelet, level: int | None = None, mode: str = 'reflect', precision: str = 'highest') → List[Array | Tuple[Array, Array, Array]]

Compute the two dimensional wavelet analysis transform on the last two dimensions of the input data array.

Parameters:
  • data (jnp.ndarray) – JAX array containing the data to be transformed. Assumed shape: [batch_size, height, width].

  • wavelet (pywt.Wavelet) – A named tuple containing the filters for the transformation.

  • level (int) – The max level to be used, if not set as many levels as possible will be used. Defaults to None.

  • mode (str) – The desired padding mode. Choose reflect, symmetric or zero. Defaults to reflect.

  • precision (str) – The desired precision, choose “fastest”, “high” or “highest”. Defaults to “highest”.

Returns:

The wavelet coefficients of shape [batch, height, width] in a nested list.

The coefficients are in pywt order. That is: [cAn, (cHn, cVn, cDn), … (cH1, cV1, cD1)]. A denotes approximation, H horizontal, V vertical and D diagonal coefficients.

Return type:

list

Raises:

ValueError – If the dimensionality of the input data array is unsupported.

Examples

>>> import pywt, scipy.datasets
>>> import jaxwt as jwt
>>> import jax.numpy as jnp
>>> face = jnp.transpose(scipy.datasets.face(), [2, 0, 1])
>>> face = face.astype(jnp.float64)
>>> jwt.wavedec2(face, pywt.Wavelet("haar"), level=2)
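
A single 2D analysis level applies the 1D filter pair along the width and then along the height, producing four subbands. The NumPy sketch below does this for the Haar wavelet on a [batch, height, width] array, assuming jaxwt follows pywt's dwt2 naming: cH is high-pass along the height axis, cV high-pass along the width axis and cD high-pass along both. `haar_1d` and `haar_dwt2` are hypothetical helpers; a multi-level wavedec2 would repeat the step on cA.

```python
import numpy as np

SQRT2 = np.sqrt(2.0)


def haar_1d(x, axis):
    """One Haar analysis step along `axis` (even length assumed)."""
    even = np.take(x, np.arange(0, x.shape[axis], 2), axis=axis)
    odd = np.take(x, np.arange(1, x.shape[axis], 2), axis=axis)
    return (even + odd) / SQRT2, (even - odd) / SQRT2


def haar_dwt2(x):
    """Single-level 2D Haar step on the last two axes: (cA, (cH, cV, cD))."""
    low_w, high_w = haar_1d(x, axis=-1)  # filter along the width
    c_a, c_h = haar_1d(low_w, axis=-2)   # low-pass width, then low/high along height
    c_v, c_d = haar_1d(high_w, axis=-2)  # high-pass width, then low/high along height
    return c_a, (c_h, c_v, c_d)


# A batch of one 4x4 image whose values change only along the width axis.
image = np.tile(np.arange(4.0), (1, 4, 1))
c_a, (c_h, c_v, c_d) = haar_dwt2(image)
print(c_a.shape, c_h.shape)  # (1, 2, 2) (1, 2, 2)
```

Since the test image varies only along its width, cH and cD vanish and all of the detail lands in cV (every entry equals -1), while cA = [[1, 5], [1, 5]].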

jaxwt.conv_fwt_2d.waverec2(coeffs: List[Array | Tuple[Array, Array, Array]], wavelet: Wavelet, precision: str = 'highest') → Array

Compute a two-dimensional synthesis wavelet transform.

Use it to reconstruct the original input image from the wavelet coefficients.

Parameters:
  • coeffs (list) – The input coefficients, typically the output of wavedec2.

  • wavelet (pywt.Wavelet) – The named tuple containing the filters used to compute the analysis transform.

  • precision (str) – The desired precision, choose “fastest”, “high” or “highest”. Defaults to “highest”.

Returns:

Reconstruction of the original input data array of shape [batch, height, width].

Return type:

jnp.array

Raises:

ValueError – If coeffs does not have the structure returned by wavedec2.

Example

>>> import pywt, scipy.datasets
>>> import jaxwt as jwt
>>> import jax.numpy as jnp
>>> face = jnp.transpose(scipy.datasets.face(), [2, 0, 1])
>>> face = face.astype(jnp.float64)
>>> transformed = jwt.wavedec2(face, pywt.Wavelet("haar"))
>>> jwt.waverec2(transformed, pywt.Wavelet("haar"))
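
waverec2 inverts wavedec2 level by level: each (cH, cV, cD) tuple is merged with the current approximation, first along the height and then along the width. The NumPy sketch below round-trips a random batch with Haar filters and raises a ValueError when an entry after cA_n is not a 3-tuple, loosely mirroring the documented error. All helper names are hypothetical, and the subband naming assumes pywt's dwt2 convention.

```python
import numpy as np

SQRT2 = np.sqrt(2.0)


def haar_1d(x, axis):
    """One Haar analysis step along `axis` (even length assumed)."""
    even = np.take(x, np.arange(0, x.shape[axis], 2), axis=axis)
    odd = np.take(x, np.arange(1, x.shape[axis], 2), axis=axis)
    return (even + odd) / SQRT2, (even - odd) / SQRT2


def ihaar_1d(approx, detail, axis):
    """Undo haar_1d and interleave even/odd samples along a negative `axis`."""
    even = (approx + detail) / SQRT2
    odd = (approx - detail) / SQRT2
    shape = list(even.shape)
    shape[axis] *= 2
    # With a negative axis, stacking puts the pair axis right after `axis`,
    # so the reshape interleaves even and odd samples.
    return np.stack([even, odd], axis=axis).reshape(shape)


def haar_wavedec2(x, level):
    """Haar analysis on the last two axes: [cA_n, (cH_n, cV_n, cD_n), ...]."""
    approx, details = np.asarray(x, dtype=np.float64), []
    for _ in range(level):
        low_w, high_w = haar_1d(approx, axis=-1)
        approx, c_h = haar_1d(low_w, axis=-2)
        c_v, c_d = haar_1d(high_w, axis=-2)
        details.append((c_h, c_v, c_d))
    return [approx] + details[::-1]


def haar_waverec2(coeffs):
    """Rebuild the input from a wavedec2-style list, coarsest level first."""
    approx = coeffs[0]
    for entry in coeffs[1:]:
        if not (isinstance(entry, tuple) and len(entry) == 3):
            raise ValueError("Expected (cH, cV, cD) tuples after cA_n.")
        c_h, c_v, c_d = entry
        low_w = ihaar_1d(approx, c_h, axis=-2)
        high_w = ihaar_1d(c_v, c_d, axis=-2)
        approx = ihaar_1d(low_w, high_w, axis=-1)
    return approx


rng = np.random.default_rng(0)
images = rng.normal(size=(3, 8, 8))
coeffs = haar_wavedec2(images, level=2)
restored = haar_waverec2(coeffs)
print(np.allclose(restored, images))  # True
```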

jaxwt.continuous_transform module

JAX-compatible continuous wavelet transform (CWT) code.

jaxwt.continuous_transform.cwt(data: Array, scales: ndarray | Array, wavelet: ContinuousWavelet | str, sampling_period: float = 1.0) → Tuple[Array, Array]

Compute the single dimensional continuous wavelet transform.

This function is a jax port of pywt.cwt as found at: https://github.com/PyWavelets/pywt/blob/master/pywt/_cwt.py

Parameters:
  • data (jnp.ndarray) – The input tensor of shape [batch_size, time].

  • scales (np.ndarray or jnp.array) – The wavelet scales to use. The physical frequency f that corresponds to a scale is f = pywt.scale2frequency(wavelet, scale)/sampling_period; f is in hertz when sampling_period is given in seconds.

  • wavelet (ContinuousWavelet or str) – The continuous wavelet to work with.

  • sampling_period (float) – Sampling period for the frequencies output (optional). The values computed for coefs are independent of the choice of sampling_period (i.e. scales is not scaled by the sampling period).

Raises:

ValueError – If a scale is too small for the input signal.

Returns:

A tuple with the transformation matrix and the frequencies, in this order.

Return type:

Tuple[jnp.ndarray, jnp.ndarray]

Example

>>> import jaxwt as jwt
>>> import jax.numpy as jnp
>>> import scipy.signal as signal
>>> t = jnp.linspace(-2, 2, 800, endpoint=False)
>>> sig = signal.chirp(t, f0=1, f1=12, t1=2, method="linear")
>>> widths = jnp.arange(1, 31)
>>> cwtmatr, freqs = jwt.cwt(
...     jnp.array(sig), widths, "mexh",
...     sampling_period=(4 / 800) * jnp.pi
... )
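
The scales parameter is tied to physical frequencies through pywt.scale2frequency, which divides the wavelet's central frequency (in cycles per sample) by the scale; dividing by sampling_period then gives cycles per unit time. The sketch below spells out that relation. `scales_to_frequencies` is a hypothetical helper, and the central frequency is passed in explicitly rather than queried from pywt, so the numbers are illustrative only.

```python
import numpy as np


def scales_to_frequencies(scales, central_frequency, sampling_period=1.0):
    """Map CWT scales to frequencies via f = central_frequency / (scale * period).

    Assumes the same relation as pywt.scale2frequency(...) / sampling_period.
    """
    scales = np.asarray(scales, dtype=np.float64)
    return central_frequency / (scales * sampling_period)


# Hypothetical values: central frequency 0.25 cycles/sample, 100 Hz sampling.
freqs = scales_to_frequencies([1, 2, 4], central_frequency=0.25, sampling_period=0.01)
print(freqs)  # approximately 25, 12.5 and 6.25 (Hz)
```

Larger scales map to lower frequencies, which is why a growing range of widths in the example above probes progressively slower oscillations of the chirp.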

jaxwt.packets module

Compute wavelet packets using jwt.

class jaxwt.packets.WaveletPacket(data: Array, wavelet: Wavelet, mode: str = 'reflect', max_level: int | None = None)

Bases: UserDict

A wavelet packet tree.

Create a wavelet packet decomposition object.

Parameters:
  • data (jnp.ndarray) – The input data array of shape [batch_size, time].

  • wavelet (pywt.Wavelet) – The wavelet used for the decomposition.

  • mode (str) – The desired padding method. Choose, for example, “reflect”, “symmetric” or “zero”. Defaults to “reflect”.

  • max_level (int, optional) – Choose the desired decomposition level.

Example

>>> import pywt
>>> import jax.numpy as jnp
>>> from jaxwt import WaveletPacket
>>> import scipy.signal as signal
>>> wavelet = pywt.Wavelet("db4")
>>> t = jnp.linspace(0, 10, 5001)
>>> w = signal.chirp(t, f0=0.00001,
...                  f1=20, t1=10, method="linear")
>>> wp = WaveletPacket(data=w, wavelet=wavelet)
>>> nodes = wp.get_level(7)
>>> jnp_lst = []
>>> for node in nodes:
...     jnp_lst.append(wp[node])
>>> viz = jnp.concatenate(jnp_lst)
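
Since the packet classes derive from UserDict, each node is stored under a path string, which is how the example above indexes wp[node]. The hypothetical NumPy class below builds such a tree for the Haar wavelet only and ignores padding modes; following pywt, it assumes "a" marks the low-pass (approximation) branch and "d" the high-pass (detail) branch, with the root stored under "".

```python
from collections import UserDict

import numpy as np

SQRT2 = np.sqrt(2.0)


class HaarPacketSketch(UserDict):
    """Hypothetical, simplified wavelet packet tree keyed by 'a'/'d' paths."""

    def __init__(self, data, max_level):
        super().__init__()
        self[""] = np.asarray(data, dtype=np.float64)
        for level in range(max_level):
            # Split every node of the current level into its two children.
            for path in [key for key in self.data if len(key) == level]:
                x = self[path]
                even, odd = x[..., 0::2], x[..., 1::2]
                self[path + "a"] = (even + odd) / SQRT2
                self[path + "d"] = (even - odd) / SQRT2


sig = np.sin(np.linspace(0, 8 * np.pi, 64))
tree = HaarPacketSketch(sig[None, :], max_level=3)
level_3 = sorted(key for key in tree if len(key) == 3)
print(level_3)  # ['aaa', 'aad', 'ada', 'add', 'daa', 'dad', 'dda', 'ddd']
```

Unlike a plain wavedec, both the approximation and the detail branch are split again, so level n holds 2^n nodes that together tile the whole frequency range.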
get_level(level: int) → List[str]

Return the node names of a given level in Gray-code order.

Parameters:

level (int) – The required depth of the tree.

Returns:

A list with the node names.

Return type:

list
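
Gray-code order lists a level's nodes so that neighbouring names differ in a single filter choice. For wavelet packets this ordering is commonly used to sort leaves by frequency band, because downsampling a high-pass branch reverses the frequency order of its children. The hypothetical helper below builds the reflected Gray code over the letters "a" and "d"; we assume get_level produces the same sequence.

```python
def graycode_order(level, low="a", high="d"):
    """Return the packet node names of one level in reflected Gray-code order.

    Each step prefixes the current order with `low` and its reversal with
    `high`, so neighbouring names differ in exactly one position.
    """
    order = [""]
    for _ in range(level):
        order = [low + path for path in order] + [high + path for path in order[::-1]]
    return order


print(graycode_order(2))  # ['aa', 'ad', 'dd', 'da']
```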

reconstruct() → WaveletPacket

Recursively reconstruct the input starting from the leaf nodes.

Reconstruction replaces the input data originally assigned to this object.

Note

Only changes to leaf node data impact the results, since all other nodes are overwritten with a reconstruction from the leaves.

Example

>>> import jaxwt as jwt
>>> import jax
>>> key = jax.random.PRNGKey(0)
>>> input_data = jax.random.normal(key, (1, 24))
>>> jwp = jwt.WaveletPacket(input_data, "haar", max_level=2)
>>> jwp["a" * 2] *= 0
>>> jwp.reconstruct()
>>> print(jwp[""])
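
The reconstruct example can be mirrored with a Haar-only NumPy sketch. Rebuilding runs from the deepest level upward and overwrites every inner node with the synthesis of its two children, which is why edits to inner nodes are lost. `haar_packet_tree` and `reconstruct` are hypothetical helpers, and "a"/"d" are assumed to name the low- and high-pass branches.

```python
import numpy as np

SQRT2 = np.sqrt(2.0)


def haar_packet_tree(data, max_level):
    """Hypothetical helper: dict of Haar packet nodes keyed by 'a'/'d' paths."""
    tree = {"": np.asarray(data, dtype=np.float64)}
    for level in range(max_level):
        for path in [key for key in tree if len(key) == level]:
            even, odd = tree[path][..., 0::2], tree[path][..., 1::2]
            tree[path + "a"] = (even + odd) / SQRT2
            tree[path + "d"] = (even - odd) / SQRT2
    return tree


def reconstruct(tree, max_level):
    """Rebuild every inner node from its children, deepest level first."""
    for level in range(max_level - 1, -1, -1):
        for path in [key for key in tree if len(key) == level]:
            approx, detail = tree[path + "a"], tree[path + "d"]
            even = (approx + detail) / SQRT2
            odd = (approx - detail) / SQRT2
            # Interleave even/odd samples; this overwrites the inner node.
            tree[path] = np.stack([even, odd], axis=-1).reshape(*even.shape[:-1], -1)
    return tree


rng = np.random.default_rng(0)
input_data = rng.normal(size=(1, 24))
tree = haar_packet_tree(input_data, max_level=2)
tree["aa"] = tree["aa"] * 0  # edit a leaf: this survives reconstruction
tree["a"] = tree["a"] * 100  # edit an inner node: reconstruction discards it
reconstruct(tree, max_level=2)
```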

class jaxwt.packets.WaveletPacket2D(data: Array, wavelet: str | Wavelet, mode: str = 'reflect', max_level: int | None = None)

Bases: UserDict

A wavelet packet tree.

Create a 2D-wavelet packet decomposition object.

Example code illustrating the use of this class is available at: https://github.com/v0lta/Jax-Wavelet-Toolbox/tree/packet-patch/examples/deepfake_analysis

Parameters:
  • data (jnp.ndarray) – The input data array of shape [batch_size, height, width].

  • wavelet (pywt.Wavelet or str) – The wavelet used for the decomposition.

  • mode (str) – The desired padding method. Choose, for example, “reflect”, “symmetric” or “zero”. Defaults to “reflect”.

  • max_level (int, optional) – Choose the desired decomposition level.

reconstruct() → WaveletPacket2D

Recursively reconstruct the input starting from the leaf nodes.

Note

Only changes to leaf node data impact the results, since all other nodes are overwritten with a reconstruction from the leaves.

jaxwt.utils module

Various utility functions.

jaxwt.utils.flatten_2d_coeff_lst(coeff_list_2d: List[Array | Tuple[Array, Array, Array]], flatten_arrays: bool = True) → List[Array]

Flattens a list of array tuples into a single list.

Parameters:
  • coeff_list_2d (list) – A pywt-style coefficient list.

  • flatten_arrays (bool) – If True, the 2D arrays are flattened. Defaults to True.

Returns:

A single 1-d list with all original elements.

Return type:

list
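
The documented behaviour of flatten_2d_coeff_lst can be sketched in a few lines: the approximation array and every (cH, cV, cD) tuple are unpacked into one flat list, and each array is optionally flattened to 1D. The function below is a hypothetical NumPy re-implementation for unbatched 2D arrays; how jaxwt treats a leading batch dimension when flattening is not stated here.

```python
import numpy as np


def flatten_2d_coeff_lst_sketch(coeff_list_2d, flatten_arrays=True):
    """Unpack [cA, (cH, cV, cD), ...] into [cA, cH, cV, cD, ...]."""
    flat = []
    for entry in coeff_list_2d:
        # The approximation is a bare array; detail levels are tuples.
        arrays = entry if isinstance(entry, tuple) else (entry,)
        for array in arrays:
            flat.append(array.reshape(-1) if flatten_arrays else array)
    return flat


coeffs = [np.zeros((2, 2)), (np.ones((2, 2)),) * 3, (np.ones((4, 4)),) * 3]
flat = flatten_2d_coeff_lst_sketch(coeffs)
print([a.shape for a in flat])  # [(4,), (4,), (4,), (4,), (16,), (16,), (16,)]
```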

jaxwt.version module

Version information for jwt.

Run with python -m jaxwt.version

jaxwt.version.get_git_hash() → str

Get the jaxwt git hash.

jaxwt.version.get_version(with_git_hash: bool = False) → str

Get the jaxwt version string, including the git hash if with_git_hash is True.