jaxwt package

jaxwt.conv_fwt module

Convolution based fast wavelet transforms.

jaxwt.conv_fwt.wavedec(data: Array, wavelet: Wavelet | str, mode: str = 'symmetric', level: int | None = None, axis: int = -1, precision: str = 'highest') -> List[Array]

Compute the analysis wavelet transform along a single axis, by default the last one.

Parameters:
  • data (jnp.ndarray) – Input data array, e.g. of shape [batch, time].

  • wavelet (Union[pywt.Wavelet, str]) – A wavelet name-string or a wavelet object containing the wavelet filter arrays. Check pywt.wavelist() for a list of options.

  • mode (str) – The padding used to extend the input signal. Choose reflect, symmetric or zero. Defaults to symmetric.

  • level (int) – Maximum scale level to be used. If None, as many levels as possible are used. Defaults to None.

  • axis (int) – Compute the transform over this axis instead of the last one. Defaults to -1.

  • precision (str) – For desired precision, choose “fastest”, “high” or “highest”. Defaults to “highest”.

Returns:

List containing the wavelet coefficients. The coefficients are in pywt order:

[cA_n, cD_n, cD_n-1, …, cD_2, cD_1].

A denotes approximation and D detail coefficients.

Return type:

list

Raises:

ValueError – If the axis argument is not an integer.

Examples

>>> import pywt
>>> import jaxwt as jwt
>>> import jax.numpy as jnp
>>> # generate an input of even length.
>>> data = jnp.array([0., 1., 2., 3, 4, 5, 5, 4, 3, 2, 1, 0])
>>> jwt.wavedec(data, wavelet=pywt.Wavelet('haar'), level=2)
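
The coefficient ordering described above can be illustrated without jaxwt. The following is a minimal pure-Python sketch, assuming the Haar wavelet, an even length at every level and no boundary padding. The helper name haar_wavedec and the sign convention of the detail coefficients are assumptions made for illustration, not part of the jaxwt API.

```python
import math


def haar_wavedec(data, level):
    """Multi-level Haar analysis of an even-length sequence (illustrative only)."""
    coeffs = []
    approx = list(data)
    for _ in range(level):
        if len(approx) % 2:
            raise ValueError("this sketch needs an even length at every level")
        pairs = list(zip(approx[0::2], approx[1::2]))
        # Assumed sign convention: detail = (even sample - odd sample) / sqrt(2).
        detail = [(a - b) / math.sqrt(2.0) for a, b in pairs]
        approx = [(a + b) / math.sqrt(2.0) for a, b in pairs]
        # Fine levels are computed first, so prepend to end up coarse-to-fine.
        coeffs.insert(0, detail)
    coeffs.insert(0, approx)
    return coeffs


data = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
coeffs = haar_wavedec(data, level=2)
print([len(c) for c in coeffs])  # lengths of [cA_2, cD_2, cD_1] -> [3, 3, 6]
```

The list starts with the coarsest approximation and ends with the finest detail band, which is the layout the Returns section describes.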

jaxwt.conv_fwt.waverec(coeffs: List[Array], wavelet: Wavelet | str, axis: int = -1, precision: str = 'highest') -> Array

Reconstruct the original signal in one dimension.

Parameters:
  • coeffs (List[jnp.ndarray]) – Wavelet coefficients, typically produced by the wavedec function. List entries of shape [batch_size, coefficients] are supported.

  • wavelet (Union[pywt.Wavelet, str]) – A string with a wavelet name or a wavelet object containing the wavelet filters used to evaluate the decomposition.

  • axis (int) – Transform this axis instead of the last one. Defaults to -1.

  • precision (str) – The desired precision, choose “fastest”, “high” or “highest”. Defaults to “highest”.

Returns:

Reconstruction of the original data.

Return type:

jnp.ndarray

Raises:

ValueError – If the axis argument is not an integer.

Examples

>>> import pywt
>>> import jaxwt as jwt
>>> import jax.numpy as jnp
>>> # generate an input of even length.
>>> data = jnp.array([0., 1., 2., 3, 4, 5, 5, 4, 3, 2, 1, 0])
>>> transformed = jwt.wavedec(data, pywt.Wavelet('haar'))
>>> jwt.waverec(transformed, pywt.Wavelet('haar'))
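
To show how a synthesis transform consumes a coefficient list in the order [cA_n, cD_n, …, cD_1], here is a minimal pure-Python sketch for the Haar wavelet. The helper haar_waverec is hypothetical and ignores padding, so it only illustrates the idea and is not how jaxwt implements waverec.

```python
import math


def haar_waverec(coeffs):
    """Invert a multi-level Haar decomposition given as [cA_n, cD_n, ..., cD_1]."""
    approx = list(coeffs[0])
    for detail in coeffs[1:]:  # walk from the coarsest to the finest level
        if len(detail) != len(approx):
            raise ValueError("approximation and detail lengths must match")
        merged = []
        for a, d in zip(approx, detail):
            merged.append((a + d) / math.sqrt(2.0))  # even sample
            merged.append((a - d) / math.sqrt(2.0))  # odd sample
        approx = merged
    return approx


# Hand-computed two-level Haar coefficients of the signal [1, 3, 5, 7].
coeffs = [[8.0], [-4.0], [-math.sqrt(2.0), -math.sqrt(2.0)]]
reconstruction = haar_waverec(coeffs)
print([round(v, 6) for v in reconstruction])  # [1.0, 3.0, 5.0, 7.0]
```

Each pass doubles the length of the running approximation until the original resolution is reached.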

jaxwt.conv_fwt_2d module

Two dimensional convolution based fast wavelet transforms.

jaxwt.conv_fwt_2d.wavedec2(data: Array, wavelet: Wavelet | str, mode: str = 'symmetric', level: int | None = None, axes: Tuple[int, int] = (-2, -1), precision: str = 'highest') -> List[Array | Tuple[Array, Array, Array]]

Compute the two-dimensional wavelet analysis transform on the last two dimensions of the input data array.

Parameters:
  • data (jnp.ndarray) – Jax array containing the data to be transformed. A possible input shape would be [batch size, height, width].

  • wavelet (Union[pywt.Wavelet, str]) – A wavelet object or wavelet string for the transformation. Check pywt.wavelist() for a list of options.

  • mode (str) – The desired padding mode. Choose “reflect”, “symmetric” or “zero”. Defaults to symmetric.

  • level (int) – The max level to be used, if not set as many levels as possible will be used. Defaults to None.

  • axes (Tuple[int, int]) – Compute the transform over these axes instead of the last two. Defaults to (-2, -1).

  • precision (str) – For the desired precision, choose “fastest”, “high” or “highest”. Defaults to “highest”.

Returns:

The wavelet coefficients of shape [batch, height, width] in a nested list.

The coefficients are in pywt order. That is: [cAn, (cHn, cVn, cDn), … (cH1, cV1, cD1)]. A denotes approximation, H horizontal, V vertical and D diagonal coefficients.

Return type:

list

Raises:

ValueError – If the axes tuple does not have two elements or contains a repetition.

Examples

>>> import pywt, scipy.datasets
>>> import jaxwt as jwt
>>> import jax.numpy as jnp
>>> face = jnp.transpose(scipy.datasets.face(), [2, 0, 1])
>>> face = face.astype(jnp.float32)
>>> jwt.wavedec2(face, "haar", level=2)

jaxwt.conv_fwt_2d.waverec2(coeffs: List[Array | Tuple[Array, Array, Array]], wavelet: Wavelet | str, axes: Tuple[int, int] = (-2, -1), precision: str = 'highest') -> Array

Compute a two-dimensional synthesis wavelet transform.

Use it to reconstruct the original input image from the wavelet coefficients.

Parameters:
  • coeffs (list) – The input coefficients, typically the output of wavedec2.

  • wavelet (Union[pywt.Wavelet, str]) – The wavelet to use for the synthesis transform.

  • axes (Tuple[int, int]) – Compute the transform over these axes instead of the last two. Defaults to (-2, -1).

  • precision (str) – For desired precision, choose “fastest”, “high” or “highest”. Defaults to “highest”.

Raises:

ValueError – If the axes tuple does not have two elements or contains a repetition.

Returns:

Reconstruction of the original input data array of shape [batch, height, width].

Return type:

jnp.ndarray

Example

>>> import pywt, scipy.datasets
>>> import jaxwt as jwt
>>> import jax.numpy as jnp
>>> face = jnp.transpose(scipy.datasets.face(), [2, 0, 1])
>>> face = face.astype(jnp.float32)
>>> transformed = jwt.wavedec2(face, "haar")
>>> jwt.waverec2(transformed, "haar")

jaxwt.conv_fwt_3d module

Three-dimensional transformation support.

jaxwt.conv_fwt_3d.wavedec3(data: Array, wavelet: Wavelet | str, mode: str = 'symmetric', level: int | None = None, axes: Tuple[int, int, int] = (-3, -2, -1), precision: str = 'highest') -> List[Array | Dict[str, Array]]

Compute the three-dimensional wavelet analysis transform on the last three dimensions of the input data array.

Parameters:
  • data (jnp.ndarray) – Jax array containing the data to be transformed. A possible input shape would be [batch size, channels, height, width].

  • wavelet (Union[pywt.Wavelet, str]) – A wavelet object or str for the transformation.

  • mode (str) – The desired padding mode. Choose reflect, symmetric or zero. Defaults to symmetric.

  • level (int) – The max level to be used, if not set as many levels as possible will be used. Defaults to None.

  • axes (Tuple[int, int, int]) – Compute the transform over these axes instead of the last three. Defaults to (-3, -2, -1).

  • precision (str) – For desired precision, choose “fastest”, “high” or “highest”. Defaults to “highest”.

Returns:

A list with the lll coefficients, i.e. the approximation obtained by low-pass filtering all three axes, and dictionaries with the filter order strings

("aad", "ada", "add", "daa", "dad", "dda", "ddd")

as keys. Here a denotes the low-pass (approximation) filter and d the high-pass (detail) filter.

Return type:

list

Raises:
  • ValueError – If the input has fewer than three dimensions.

  • ValueError – If the axes tuple does not have three elements or contains a repetition.

Examples

>>> import pywt
>>> import jaxwt as jwt
>>> import jax
>>> data = jax.random.uniform(jax.random.PRNGKey(42),
...                           [3, 16, 16, 16])
>>> jwt.wavedec3(data, "haar", level=2)
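
The seven dictionary keys listed in the Returns section are all three-letter combinations of a and d except the pure approximation. A small sketch of that enumeration follows. Which letter position maps to which axis is not stated here, so the code makes no claim about it.

```python
from itertools import product

# All three-letter combinations of "a" (low-pass) and "d" (high-pass).
all_keys = ["".join(letters) for letters in product("ad", repeat=3)]
# "aaa" is the pure approximation band; the remaining seven are detail keys.
detail_keys = tuple(key for key in all_keys if key != "aaa")
print(detail_keys)
# ('aad', 'ada', 'add', 'daa', 'dad', 'dda', 'ddd')
```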

jaxwt.conv_fwt_3d.waverec3(coeffs: List[Array | Dict[str, Array]], wavelet: Wavelet | str, axes: Tuple[int, int, int] = (-3, -2, -1), precision: str = 'highest') -> Array

Compute a three-dimensional synthesis wavelet transform.

Use it to reconstruct the original input data from the wavelet coefficients.

Parameters:
  • coeffs (list) – The input coefficients, typically the output of wavedec3.

  • wavelet (Union[pywt.Wavelet, str]) – The wavelet to use for the synthesis transform.

  • axes (Tuple[int, int, int]) – Transform these axes instead of the last three. Defaults to (-3, -2, -1).

  • precision (str) – For desired precision, choose “fastest”, “high” or “highest”. Defaults to “highest”.

Returns:

Reconstruction of the original input data array, for example of shape [batch, channels, height, width].

Return type:

jnp.ndarray

Raises:

ValueError – If the axes tuple does not have three elements or contains a repetition.

Example

>>> import pywt
>>> import jaxwt as jwt
>>> import jax
>>> data = jax.random.uniform(jax.random.PRNGKey(42),
...                           [3, 16, 16, 16])
>>> rec = jwt.waverec3(jwt.wavedec3(data, "haar", level=2), "haar")
>>> jax.numpy.allclose(data, rec)

jaxwt.continuous_transform module

Jax compatible cwt code.

jaxwt.continuous_transform.cwt(data: Array, scales: ndarray | Array, wavelet: ContinuousWavelet | str, sampling_period: float = 1.0) -> Tuple[Array, Array]

Compute the single-dimensional continuous wavelet transform.

This function is a jax port of pywt.cwt as found at: https://github.com/PyWavelets/pywt/blob/master/pywt/_cwt.py

Parameters:
  • data (jnp.ndarray) – The input tensor of shape [batch_size, time].

  • scales (np.ndarray or jnp.array) – The wavelet scales to use. One can use f = pywt.scale2frequency(wavelet, scale)/sampling_period to determine the physical frequency f that corresponds to a given scale. Here, f is in hertz when the sampling_period is given in seconds.

  • wavelet (ContinuousWavelet or str) – The continuous wavelet to work with.

  • sampling_period (float) – Sampling period for the frequencies output (optional). The values computed for coefs are independent of the choice of sampling_period (i.e. scales is not scaled by the sampling period).

Raises:

ValueError – If a scale is too small for the input signal.

Returns:

A tuple with the transformation matrix and the frequencies, in this order.

Return type:

Tuple[jnp.ndarray, jnp.ndarray]

Example

>>> import jaxwt as jwt
>>> import jax.numpy as jnp
>>> import scipy.signal as signal
>>> t = jnp.linspace(-2, 2, 800, endpoint=False)
>>> sig = signal.chirp(t, f0=1, f1=12, t1=2, method="linear")
>>> widths = jnp.arange(1, 31)
>>> cwtmatr, freqs = jwt.cwt(
...     jnp.array(sig), widths, "mexh",
...     sampling_period=(4 / 800) * jnp.pi
... )
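
The relation between scales and frequencies given for the scales parameter, f = pywt.scale2frequency(wavelet, scale)/sampling_period, reduces to dividing the wavelet's center frequency by the product of scale and sampling period. The sketch below spells this out in plain Python. The helper scales_to_frequencies is hypothetical, and the center frequency of 0.25 is an assumed illustrative value that should be checked against pywt for the wavelet actually in use.

```python
import math


def scales_to_frequencies(scales, center_frequency, sampling_period=1.0):
    """Map wavelet scales to physical frequencies, f = f_c / (scale * dt)."""
    if sampling_period <= 0:
        raise ValueError("sampling_period must be positive")
    return [center_frequency / (scale * sampling_period) for scale in scales]


# Assumed center frequency for illustration; query pywt for the real value.
assumed_center_frequency = 0.25
widths = range(1, 31)
freqs = scales_to_frequencies(
    widths, assumed_center_frequency, sampling_period=(4 / 800) * math.pi
)
print(freqs[0], freqs[-1])  # the smallest scale maps to the highest frequency
```

Larger scales stretch the wavelet and therefore correspond to lower frequencies, so the returned list decreases monotonically.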

jaxwt.packets module

Compute wavelet packets using jwt.

class jaxwt.packets.WaveletPacket(data: Array, wavelet: Wavelet | str, mode: str = 'symmetric', max_level: int | None = None)

Bases: UserDict

A wavelet packet tree.

Create a wavelet packet decomposition object.

Parameters:
  • data (jnp.ndarray) – The input data array of shape [batch_size, time].

  • wavelet (Union[pywt.Wavelet, str]) – The wavelet used for the decomposition.

  • mode (str) – The desired padding method. Choose e.g. “reflect”, “symmetric” or “zero”. Defaults to “symmetric”.

  • max_level (int, optional) – Choose the desired decomposition level. Defaults to None.

Example

>>> import pywt
>>> import jax.numpy as jnp
>>> from jaxwt import WaveletPacket
>>> import scipy.signal as signal
>>> wavelet = pywt.Wavelet("db4")
>>> t = jnp.linspace(0, 10, 5001)
>>> w = signal.chirp(t, f0=0.00001,
...                  f1=20, t1=10, method="linear")
>>> wp = WaveletPacket(data=w, wavelet=wavelet)
>>> nodes = wp.get_level(7)
>>> jnp_lst = []
>>> for node in nodes:
...     jnp_lst.append(wp[node])
>>> viz = jnp.concatenate(jnp_lst)

get_level(level: int) -> List[str]

Return the node names for a given level in Gray-code order.

Parameters:

level (int) – The required depth of the tree.

Returns:

A list with the node names.

Return type:

list
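
A Gray-code ordering of packet node names can be generated with the standard binary-reflected construction, in which each deeper level prefixes the current list with a and its mirror image with d. The sketch below is a plain-Python illustration. The function name graycode_order is hypothetical, and whether get_level builds its list in exactly this way is an assumption.

```python
def graycode_order(level, low="a", high="d"):
    """Return Gray-code-ordered node names for one level of a 1D packet tree."""
    if level < 1:
        raise ValueError("level must be at least 1")
    order = [low, high]
    for _ in range(level - 1):
        # Prefix the list with `low`, then the mirrored list with `high`.
        order = [low + path for path in order] + [
            high + path for path in order[::-1]
        ]
    return order


print(graycode_order(2))  # ['aa', 'ad', 'dd', 'da']
print(len(graycode_order(7)))  # 128
```

Neighbouring names differ in exactly one letter, which is the defining property of a Gray code.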

reconstruct() -> WaveletPacket

Recursively reconstruct the input starting from the leaf nodes.

Reconstruction replaces the input data originally assigned to this object.

Note

Only changes to leaf node data affect the result, since the data in all other nodes is replaced by a reconstruction from the leaves.

Example

>>> import jaxwt as jwt
>>> import jax
>>> key = jax.random.PRNGKey(0)
>>> input_data = jax.random.normal(key, (1, 24))
>>> jwp = jwt.WaveletPacket(input_data, "haar", max_level=2)
>>> jwp["a" * 2] *= 0
>>> jwp.reconstruct()
>>> print(jwp[""])

class jaxwt.packets.WaveletPacket2D(data: Array, wavelet: str | Wavelet, mode: str = 'symmetric', max_level: int | None = None)

Bases: UserDict

A wavelet packet tree.

Create a 2D-wavelet packet decomposition object.

Example code illustrating the use of this class is available at: https://github.com/v0lta/Jax-Wavelet-Toolbox/tree/master/examples/deepfake_analysis

Parameters:
  • data (jnp.ndarray) – The input data array of shape [batch_size, height, width].

  • wavelet (pywt.Wavelet or str) – The wavelet used for the decomposition.

  • mode (str) – The desired padding method. Choose e.g. “reflect”, “symmetric” or “zero”. Defaults to “symmetric”.

  • max_level (int, optional) – Choose the desired decomposition level.

reconstruct() -> WaveletPacket2D

Recursively reconstruct the input starting from the leaf nodes.

Note

Only changes to leaf node data affect the result, since the data in all other nodes is replaced by a reconstruction from the leaves.

jaxwt.stationary_transform module

Code for stationary wavelet transforms.

jaxwt.stationary_transform.iswt(coeffs: List[Array], wavelet: Wavelet | str, axis: int = -1, precision: str | None = 'highest') -> Array

Compute an inverse stationary wavelet transform.

Parameters:
  • coeffs (List[jnp.ndarray]) – The coefficients as computed by the analysis code.

  • wavelet (Union[pywt.Wavelet, str]) – The wavelet used by the transform.

  • axis (int) – Transform this axis instead of the last one. Defaults to -1.

  • precision (Optional[str]) – Precision value for the underlying lax convolution code. Defaults to “highest”.

Raises:

ValueError – If the axis argument is not an integer.

Returns:

The reconstruction of the original signal.

Return type:

jnp.ndarray

Example

>>> import jax, jaxwt
>>> import jax.numpy as jnp
>>> signal = jax.random.randint(
...     jax.random.PRNGKey(42), [1, 10], 0, 9).astype(jnp.float32)
>>> jaxwt.iswt(jaxwt.swt(signal, "haar", level=2), "haar")

jaxwt.stationary_transform.swt(data: Array, wavelet: Wavelet | str, level: int | None = None, axis: int = -1, precision: str = 'highest') -> List[Array]

Compute a multilevel 1d stationary wavelet transform.

Parameters:
  • data (jnp.ndarray) – The input data of shape [batch_size, time]. This function assumes a trailing input dimension with a length divisible by two.

  • wavelet (Union[Wavelet, str]) – The wavelet to use.

  • level (Optional[int], optional) – The number of levels to compute. Defaults to None.

  • axis (int) – Compute the transform over this axis instead of the last one. Defaults to -1.

  • precision (str) – For the desired precision, choose “fastest”, “high” or “highest”. Defaults to “highest”.

Returns:

A list containing the wavelet coefficients.

The coefficients are in pywt order: [cA_n, cD_n, cD_n-1, …, cD_2, cD_1]. A denotes approximation and D detail coefficients. The ordering is identical to the wavedec function. Equivalent to pywt.swt with trim_approx=True.

Return type:

List[jnp.ndarray]

Raises:

ValueError – If the axis argument is not an integer.

Example

>>> import jax, jaxwt
>>> import jax.numpy as jnp
>>> signal = jax.random.randint(
...     jax.random.PRNGKey(42), [1, 10], 0, 9).astype(jnp.float32)
>>> jaxwt.swt(signal, "haar", level=2)
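
What distinguishes a stationary transform from wavedec is that no downsampling takes place, so every coefficient band keeps the input length. The pure-Python sketch below illustrates this for the Haar wavelet with periodic wrap-around. The helper haar_swt is hypothetical, and its boundary handling and sample alignment are simplifying assumptions that may differ from what jaxwt.swt or pywt.swt compute.

```python
import math


def haar_swt(data, level):
    """Undecimated multi-level Haar transform with periodic wrap-around (sketch)."""
    n = len(data)
    coeffs = []
    approx = list(data)
    for j in range(level):
        step = 2 ** j  # the two filter taps sit 2**j samples apart at level j + 1
        shifted = [approx[(k + step) % n] for k in range(n)]
        detail = [(a - b) / math.sqrt(2.0) for a, b in zip(approx, shifted)]
        approx = [(a + b) / math.sqrt(2.0) for a, b in zip(approx, shifted)]
        coeffs.insert(0, detail)  # no downsampling: every band keeps length n
    coeffs.insert(0, approx)
    return coeffs


signal = [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0, 5.0, 3.0]
coeffs = haar_swt(signal, level=2)
print([len(c) for c in coeffs])  # [10, 10, 10]
```

The returned list follows the same [cA_n, cD_n, …, cD_1] layout as the decimated transform, only with bands of equal length.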

jaxwt.version module

Version information for jwt.

Run with python -m jaxwt.version

jaxwt.version.get_version() -> str

Get the jaxwt version string.