jaxwt package
Submodules
jaxwt.conv_fwt module
Convolution-based fast wavelet transforms.
- jaxwt.conv_fwt.wavedec(data: Array, wavelet: Wavelet, level: Optional[int] = None, mode: str = 'reflect') → List[Array]
Compute the one-dimensional analysis wavelet transform of the last dimension.
- Parameters
data (jnp.ndarray) – Input data array of shape [batch, time].
wavelet (pywt.Wavelet) – The named tuple containing the wavelet filter arrays.
level (int) – The maximum scale level to use. If None, as many levels as possible are used. Defaults to None.
mode (str) – The padding used to extend the input signal. Choose reflect, symmetric or zero. Defaults to reflect.
- Returns
- List containing the wavelet coefficients.
The coefficients are in pywt order: [cA_n, cD_n, cD_n-1, …, cD2, cD1]. A denotes approximation and D detail coefficients.
- Return type
list
Examples
>>> import pywt
>>> import jaxwt as jwt
>>> import jax.numpy as jnp
>>> # generate an input of even length.
>>> data = jnp.array([0., 1., 2., 3, 4, 5, 5, 4, 3, 2, 1, 0])
>>> jwt.wavedec(data, wavelet=pywt.Wavelet('haar'), level=2)
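To make the coefficient layout concrete, the following plain NumPy sketch computes a multi-level Haar analysis along the last axis and returns the list in the pywt order described above. It is not jaxwt's convolution-based implementation: the name `haar_wavedec` is hypothetical, only the Haar wavelet is covered, and the input length is assumed to be divisible by 2**level so that no padding is needed.

```python
import numpy as np

def haar_wavedec(data, level):
    """Hypothetical sketch: multi-level Haar analysis along the last axis, in pywt order."""
    data = np.asarray(data, dtype=float)
    if data.shape[-1] % (2 ** level) != 0:
        raise ValueError("This sketch needs a length divisible by 2**level.")
    details = []
    approx = data
    for _ in range(level):
        even, odd = approx[..., 0::2], approx[..., 1::2]
        # Highpass (detail) and lowpass (approximation) Haar responses, downsampled by 2.
        details.append((even - odd) / np.sqrt(2.0))
        approx = (even + odd) / np.sqrt(2.0)
    # pywt order: [cA_n, cD_n, cD_n-1, ..., cD1]
    return [approx] + details[::-1]

data = np.array([0., 1., 2., 3, 4, 5, 5, 4, 3, 2, 1, 0])
coeffs = haar_wavedec(data, level=2)
print([c.round(4).tolist() for c in coeffs])  # cA2 is about [3, 9, 3], cD2 about [-2, 0, 2]
```

For Haar on such inputs the values should coincide with pywt.wavedec, which uses the same sign convention for the detail coefficients (first sample minus second, divided by sqrt(2)).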
- jaxwt.conv_fwt.waverec(coeffs: List[Array], wavelet: Wavelet) → Array
Reconstruct the original signal in one dimension.
- Parameters
coeffs (list) – Wavelet coefficients, typically produced by the wavedec function. Entries of shape [batch_size, coefficients] are supported.
wavelet (pywt.Wavelet) – The named tuple containing the wavelet filters used to evaluate the decomposition.
- Returns
Reconstruction of the original data.
- Return type
jnp.array
Examples
>>> import pywt
>>> import jaxwt as jwt
>>> import jax.numpy as jnp
>>> # generate an input of even length.
>>> data = jnp.array([0., 1., 2., 3, 4, 5, 5, 4, 3, 2, 1, 0])
>>> transformed = jwt.wavedec(data, pywt.Wavelet('haar'))
>>> jwt.waverec(transformed, pywt.Wavelet('haar'))
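Reconstruction inverts each analysis step, starting from cA_n and consuming the detail arrays from the coarsest to the finest level. The NumPy sketch below shows this round trip on a batch of two signals. It assumes the Haar wavelet and inputs whose length is divisible by 2**level, so no padding is involved; `haar_wavedec` and `haar_waverec` are hypothetical helpers, not jaxwt functions.

```python
import numpy as np

SQRT2 = np.sqrt(2.0)

def haar_wavedec(data, level):
    """Hypothetical Haar analysis in pywt order (length divisible by 2**level)."""
    approx, details = np.asarray(data, dtype=float), []
    for _ in range(level):
        even, odd = approx[..., 0::2], approx[..., 1::2]
        details.append((even - odd) / SQRT2)
        approx = (even + odd) / SQRT2
    return [approx] + details[::-1]

def haar_waverec(coeffs):
    """Hypothetical inverse: walk from cA_n through cD_n, ..., cD1."""
    approx = coeffs[0]
    for detail in coeffs[1:]:
        out = np.empty(approx.shape[:-1] + (2 * approx.shape[-1],))
        # Undo one analysis step: even = (a + d)/sqrt2, odd = (a - d)/sqrt2.
        out[..., 0::2] = (approx + detail) / SQRT2
        out[..., 1::2] = (approx - detail) / SQRT2
        approx = out
    return approx

data = np.array([[0., 1., 2., 3, 4, 5, 5, 4, 3, 2, 1, 0],
                 [1., 0., 0., 1, 1, 0, 0, 1, 1, 0, 0, 1]])  # [batch_size, time]
rec = haar_waverec(haar_wavedec(data, level=2))
print(np.allclose(rec, data))  # True
```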
jaxwt.conv_fwt_2d module
Two-dimensional convolution-based fast wavelet transforms.
- jaxwt.conv_fwt_2d.construct_2d_filt(lo: Array, hi: Array) → Array
Construct 2d filters from 1d inputs using outer products.
- Parameters
lo (jnp.ndarray) – 1d lowpass input filter of size [1, length].
hi (jnp.ndarray) – 1d highpass input filter of size [1, length].
- Returns
jnp.array: 2d filter arrays of shape [4, 1, length, length].
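The outer-product construction can be sketched in NumPy for the Haar filters as pywt lists them (dec_lo = [s, s], dec_hi = [-s, s] with s = 1/sqrt(2)). The name `construct_2d_filt_sketch` and the channel order [ll, lh, hl, hh] are assumptions made for this illustration; the description above only fixes the output shape [4, 1, length, length].

```python
import numpy as np

def construct_2d_filt_sketch(lo, hi):
    """Hypothetical sketch: four separable 2D filters from 1D lowpass/highpass filters."""
    lo, hi = np.ravel(lo), np.ravel(hi)  # accept [1, length] or [length]
    ll = np.outer(lo, lo)  # lowpass along both axes
    lh = np.outer(hi, lo)  # highpass along the first axis, lowpass along the second
    hl = np.outer(lo, hi)  # lowpass along the first axis, highpass along the second
    hh = np.outer(hi, hi)  # highpass along both axes
    # Stack to [4, length, length], then add a singleton input-channel axis.
    return np.stack([ll, lh, hl, hh], axis=0)[:, np.newaxis, :, :]

s = 1.0 / np.sqrt(2.0)
lo = np.array([[s, s]])   # Haar dec_lo, shape [1, length]
hi = np.array([[-s, s]])  # Haar dec_hi, shape [1, length]
filt = construct_2d_filt_sketch(lo, hi)
print(filt.shape)  # (4, 1, 2, 2)
flat = filt.reshape(4, -1)
print(np.allclose(flat @ flat.T, np.eye(4)))  # True: the four Haar filters are orthonormal
```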
- jaxwt.conv_fwt_2d.wavedec2(data: Array, wavelet: Wavelet, level: Optional[int] = None, mode: str = 'reflect') → List[Union[Array, Tuple[Array, Array, Array]]]
Compute the two-dimensional wavelet analysis transform on the last two dimensions of the input data array.
- Parameters
data (jnp.ndarray) – JAX array containing the data to be transformed. Assumed shape: [batch_size, height, width].
wavelet (pywt.Wavelet) – A named tuple containing the filters for the transformation.
level (int) – The maximum decomposition level. If None, as many levels as possible are used. Defaults to None.
mode (str) – The desired padding mode. Choose reflect, symmetric or zero. Defaults to reflect.
- Returns
- The wavelet coefficients in a nested list.
The coefficients are in pywt order. That is: [cAn, (cHn, cVn, cDn), … (cH1, cV1, cD1)]. A denotes approximation, H horizontal, V vertical and D diagonal coefficients.
- Return type
list
Examples
>>> import pywt, scipy.misc
>>> import jaxwt as jwt
>>> import jax.numpy as jnp
>>> face = jnp.transpose(scipy.misc.face(), [2, 0, 1])
>>> face = face.astype(jnp.float64)
>>> jwt.wavedec2(face, pywt.Wavelet("haar"), level=2)
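The meaning of the H, V and D labels follows pywt's convention: cH is highpass along the height (rows) and lowpass along the width, cV is the reverse, and cD is highpass along both. The NumPy sketch below illustrates that layout for the Haar wavelet on a batch of images whose sides are divisible by 2**level; `haar_step` and `haar_wavedec2` are hypothetical helpers, not jaxwt's implementation, and no padding is applied.

```python
import numpy as np

def haar_step(x, axis):
    """One Haar analysis step along `axis`; returns (approximation, detail)."""
    n = x.shape[axis]
    even = np.take(x, np.arange(0, n, 2), axis=axis)
    odd = np.take(x, np.arange(1, n, 2), axis=axis)
    return (even + odd) / np.sqrt(2.0), (even - odd) / np.sqrt(2.0)

def haar_wavedec2(data, level):
    """Hypothetical multi-level 2D Haar analysis on the last two axes, in pywt order."""
    approx = np.asarray(data, dtype=float)
    details = []
    for _ in range(level):
        lo_h, hi_h = haar_step(approx, axis=-2)  # low/high pass along the height
        approx, c_v = haar_step(lo_h, axis=-1)   # lowpass in height: cA and cV
        c_h, c_d = haar_step(hi_h, axis=-1)      # highpass in height: cH and cD
        details.append((c_h, c_v, c_d))
    # pywt order: [cAn, (cHn, cVn, cDn), ..., (cH1, cV1, cD1)]
    return [approx] + details[::-1]

# Batch of two 8x8 images whose values change only along the width.
data = np.tile(np.arange(8.0), (2, 8, 1))
coeffs = haar_wavedec2(data, level=2)
c_h, c_v, c_d = coeffs[1]
# cH and cD vanish here because nothing changes along the height.
print(coeffs[0].shape, np.abs(c_h).max(), np.abs(c_v).max())
```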
- jaxwt.conv_fwt_2d.waverec2(coeffs: List[Union[Array, Tuple[Array, Array, Array]]], wavelet: Wavelet) → Array
Compute a two-dimensional synthesis wavelet transform.
Use it to reconstruct the original input image from the wavelet coefficients.
- Parameters
coeffs (list) – The input coefficients, typically the output of wavedec2.
wavelet (pywt.Wavelet) – The named tuple containing the filters used to compute the analysis transform.
- Returns
Reconstruction of the original input data array of shape [batch, height, width].
- Return type
jnp.array
- Raises
ValueError – If coeffs does not have the structure returned by wavedec2.
Example
>>> import pywt, scipy.misc
>>> import jaxwt as jwt
>>> import jax.numpy as jnp
>>> face = jnp.transpose(scipy.misc.face(), [2, 0, 1])
>>> face = face.astype(jnp.float64)
>>> transformed = jwt.wavedec2(face, pywt.Wavelet("haar"))
>>> jwt.waverec2(transformed, pywt.Wavelet("haar"))
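Synthesis undoes each analysis level: the approximation is merged with cV along the width, cH with cD along the width, and the two results along the height. The NumPy sketch below shows this round trip and the kind of structural check that leads to a ValueError. It assumes the Haar wavelet, sides divisible by 2**level and no padding; all helper names are hypothetical and the error message is illustrative, not jaxwt's.

```python
import numpy as np

SQRT2 = np.sqrt(2.0)

def haar_step(x, axis):
    """One Haar analysis step along a negative `axis`: (approximation, detail)."""
    n = x.shape[axis]
    even = np.take(x, np.arange(0, n, 2), axis=axis)
    odd = np.take(x, np.arange(1, n, 2), axis=axis)
    return (even + odd) / SQRT2, (even - odd) / SQRT2

def haar_unstep(a, d, axis):
    """Invert haar_step by interleaving (a + d)/sqrt2 and (a - d)/sqrt2 along `axis`."""
    even, odd = (a + d) / SQRT2, (a - d) / SQRT2
    # For a negative axis, np.stack inserts the new axis right after `axis`,
    # so the reshape below interleaves even and odd samples.
    stacked = np.stack([even, odd], axis=axis)
    shape = list(even.shape)
    shape[axis] *= 2
    return stacked.reshape(shape)

def haar_wavedec2(data, level):
    approx, details = np.asarray(data, dtype=float), []
    for _ in range(level):
        lo_h, hi_h = haar_step(approx, axis=-2)
        approx, c_v = haar_step(lo_h, axis=-1)
        c_h, c_d = haar_step(hi_h, axis=-1)
        details.append((c_h, c_v, c_d))
    return [approx] + details[::-1]

def haar_waverec2(coeffs):
    """Synthesis in pywt order; rejects entries that are not (cH, cV, cD) tuples."""
    approx = coeffs[0]
    for entry in coeffs[1:]:
        if not isinstance(entry, tuple) or len(entry) != 3:
            raise ValueError("Expected (cH, cV, cD) tuples after the first entry.")
        c_h, c_v, c_d = entry
        lo_h = haar_unstep(approx, c_v, axis=-1)
        hi_h = haar_unstep(c_h, c_d, axis=-1)
        approx = haar_unstep(lo_h, hi_h, axis=-2)
    return approx

rng = np.random.default_rng(0)
images = rng.normal(size=(3, 16, 8))  # [batch, height, width]
rec = haar_waverec2(haar_wavedec2(images, level=2))
print(np.allclose(rec, images))  # True
```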
jaxwt.continuous_transform module
JAX-compatible continuous wavelet transform (CWT) code.
- jaxwt.continuous_transform.cwt(data: Array, scales: Union[ndarray, Array], wavelet: Union[ContinuousWavelet, str], sampling_period: float = 1.0) → Tuple[Array, Array]
Compute the one-dimensional continuous wavelet transform.
This function is a JAX port of pywt.cwt as found at: https://github.com/PyWavelets/pywt/blob/master/pywt/_cwt.py
- Parameters
data (jnp.ndarray) – The input array of shape [batch_size, time].
scales (np.ndarray or jnp.ndarray) – The wavelet scales to use. One can use f = pywt.scale2frequency(wavelet, scale) / sampling_period to determine the physical frequency f that corresponds to a scale. Here, f is in hertz when sampling_period is given in seconds.
wavelet (ContinuousWavelet or str) – The continuous wavelet to work with.
sampling_period (float) – Sampling period for the frequencies output (optional). The values computed for coefs are independent of the choice of sampling_period (i.e. scales is not scaled by the sampling period).
- Raises
ValueError – If a scale is too small for the input signal.
- Returns
- A tuple containing the coefficient matrix and the frequencies, in this order.
- Return type
Tuple[jnp.ndarray, jnp.ndarray]
Example
>>> import jaxwt as jwt
>>> import jax.numpy as jnp
>>> import scipy.signal as signal
>>> t = jnp.linspace(-2, 2, 800, endpoint=False)
>>> sig = signal.chirp(t, f0=1, f1=12, t1=2, method="linear")
>>> widths = jnp.arange(1, 31)
>>> cwtmatr, freqs = jwt.cwt(
...     jnp.array(sig), widths, "mexh",
...     sampling_period=(4 / 800) * jnp.pi
... )
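The idea behind the transform (correlating the signal with dilated copies of one wavelet, one row per scale) can be sketched by direct convolution with the Mexican hat wavelet. This is a simplified stand-in, not the integrate-and-differentiate scheme of pywt.cwt that jaxwt ports: `naive_cwt` is hypothetical, it handles a single 1D signal, uses a 1/sqrt(scale) normalization and truncates the wavelet support, and its frequencies use the analytic peak frequency of the Mexican hat rather than pywt.scale2frequency.

```python
import numpy as np

def mexh(t):
    """Mexican hat wavelet with the normalization used for pywt's "mexh"."""
    return 2.0 / (np.sqrt(3.0) * np.pi ** 0.25) * (1.0 - t ** 2) * np.exp(-(t ** 2) / 2.0)

def naive_cwt(data, scales, sampling_period=1.0):
    """Hypothetical direct-convolution CWT of a 1D signal with the Mexican hat."""
    data = np.asarray(data, dtype=float)
    n = data.shape[-1]
    coefs = np.empty((len(scales), n))
    for i, scale in enumerate(scales):
        # Truncate the support to +-8 scales, and to at most the signal length.
        half = min(int(np.ceil(8 * scale)), (n - 1) // 2)
        k = np.arange(-half, half + 1)
        kernel = mexh(k / scale) / np.sqrt(scale)
        # mexh is symmetric, so convolution equals correlation here.
        coefs[i] = np.convolve(data, kernel, mode="same")
    # Analytic peak frequency of mexh: sqrt(2) / (2 pi) cycles per unit time.
    peak = np.sqrt(2.0) / (2.0 * np.pi)
    freqs = peak / np.asarray(scales, dtype=float) / sampling_period
    return coefs, freqs

t = np.arange(800)
sig = np.cos(2 * np.pi * 0.05 * t)  # 0.05 cycles per sample
scales = np.arange(1, 31)
coefs, freqs = naive_cwt(sig, scales)
# Mean magnitude per scale, away from the boundaries.
response = np.abs(coefs[:, 250:550]).mean(axis=1)
best_scale = scales[np.argmax(response)]
print(coefs.shape, best_scale)
```

For this cosine the strongest response should sit near a scale of about 5 samples; larger scales respond to lower frequencies, as the freqs array indicates.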
jaxwt.packets module
Compute wavelet packets using jaxwt.
- class jaxwt.packets.WaveletPacket(data: Array, wavelet: Wavelet, mode: str = 'reflect', max_level: Optional[int] = None)
Bases: UserDict
A wavelet packet tree.
Create a wavelet packet decomposition object.
- Parameters
data (jnp.ndarray) – The input data array of shape [batch_size, time].
wavelet (pywt.Wavelet) – The wavelet used for the decomposition.
mode (str) – The desired padding method. Options include “reflect”, “symmetric” and “zero”. Defaults to “reflect”.
max_level (int, optional) – The desired decomposition level. Defaults to None.
Example
>>> import pywt
>>> import jax.numpy as jnp
>>> from jaxwt import WaveletPacket
>>> import scipy.signal as signal
>>> wavelet = pywt.Wavelet("db4")
>>> t = jnp.linspace(0, 10, 5001)
>>> w = signal.chirp(t, f0=0.00001,
...                  f1=20, t1=10, method="linear")
>>> wp = WaveletPacket(data=w, wavelet=wavelet)
>>> nodes = wp.get_level(7)
>>> jnp_lst = []
>>> for node in nodes:
...     jnp_lst.append(wp[node])
>>> viz = jnp.concatenate(jnp_lst)
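Unlike wavedec, a packet decomposition splits the detail nodes as well, so level n holds 2**n nodes addressed by "a"/"d" paths (the reconstruct example below uses keys such as "aa" and ""). The NumPy sketch illustrates such a tree for the Haar wavelet without padding; `haar_split` and `haar_packet_tree` are hypothetical helpers, and a plain dict stands in for the UserDict-based class.

```python
import numpy as np

def haar_split(x):
    """One Haar analysis step along the last axis: (approximation, detail)."""
    even, odd = x[..., 0::2], x[..., 1::2]
    return (even + odd) / np.sqrt(2.0), (even - odd) / np.sqrt(2.0)

def haar_packet_tree(data, max_level):
    """Hypothetical full Haar packet tree keyed by paths such as "", "a", "ad", "dda"."""
    tree = {"": np.asarray(data, dtype=float)}
    for level in range(max_level):
        # Both approximation and detail nodes of the current level are split again.
        for path in [p for p in tree if len(p) == level]:
            tree[path + "a"], tree[path + "d"] = haar_split(tree[path])
    return tree

t = np.linspace(0, 8 * np.pi, 64)
data = np.sin(t)[np.newaxis, :]  # shape [batch_size, time]
tree = haar_packet_tree(data, max_level=3)
leaves = sorted(p for p in tree if len(p) == 3)
print(leaves)  # ['aaa', 'aad', 'ada', 'add', 'daa', 'dad', 'dda', 'ddd']
```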
- get_level(level: int) → List[str]
Return the node names of a given level in Gray-code order.
- Parameters
level (int) – The required depth of the tree.
- Returns
A list with the node names.
- Return type
list
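A binary-reflected Gray code over the letters "a" and "d" can be built recursively: prefix the previous level's list with "a", then prefix its reversed copy with "d". For real-valued filter banks this ordering is commonly used to sort packet nodes by frequency band, because each highpass-and-downsample step mirrors the spectrum. The function `graycode_order` below is a hypothetical sketch; whether jaxwt builds its list exactly this way is an assumption.

```python
def graycode_order(level, x="a", y="d"):
    """Hypothetical sketch: binary-reflected Gray code over the letters x and y."""
    order = [x, y]
    for _ in range(level - 1):
        # Reflect: keep the old order under x, append the reversed order under y.
        order = [x + p for p in order] + [y + p for p in reversed(order)]
    return order

print(graycode_order(2))  # ['aa', 'ad', 'dd', 'da']
print(graycode_order(3))  # ['aaa', 'aad', 'add', 'ada', 'dda', 'ddd', 'dad', 'daa']
```

Neighbouring entries differ in exactly one letter, which is the defining property of a Gray code.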
- reconstruct() → WaveletPacket
Recursively reconstruct the input starting from the leaf nodes.
Reconstruction replaces the input data originally assigned to this object.
Note
Only changes to leaf-node data affect the result, since all other nodes are replaced with a reconstruction from the leaves.
Example
>>> import jaxwt as jwt
>>> import jax
>>> key = jax.random.PRNGKey(0)
>>> input_data = jax.random.normal(key, (1, 24))
>>> jwp = jwt.WaveletPacket(input_data, "haar", max_level=2)
>>> jwp["a" * 2] *= 0
>>> jwp.reconstruct()
>>> print(jwp[""])
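The bottom-up rule in the note can be sketched with a dict-based Haar packet tree: every inner node, from the deepest inner level up to the root "", is overwritten by the synthesis of its "a" and "d" children. With a constant input, all energy of a Haar tree sits in the "aa" node, so zeroing that leaf zeroes the reconstruction, while an edit to an inner node is simply overwritten. All helper names are hypothetical, and no padding is used.

```python
import numpy as np

SQRT2 = np.sqrt(2.0)

def haar_split(x):
    even, odd = x[..., 0::2], x[..., 1::2]
    return (even + odd) / SQRT2, (even - odd) / SQRT2

def haar_merge(a, d):
    """Inverse of haar_split along the last axis."""
    out = np.empty(a.shape[:-1] + (2 * a.shape[-1],))
    out[..., 0::2] = (a + d) / SQRT2
    out[..., 1::2] = (a - d) / SQRT2
    return out

def build_tree(data, max_level):
    tree = {"": np.asarray(data, dtype=float)}
    for level in range(max_level):
        for path in [p for p in tree if len(p) == level]:
            tree[path + "a"], tree[path + "d"] = haar_split(tree[path])
    return tree

def reconstruct(tree, max_level):
    """Overwrite every inner node, bottom-up, with the synthesis of its children."""
    for level in reversed(range(max_level)):
        for path in [p for p in tree if len(p) == level]:
            tree[path] = haar_merge(tree[path + "a"], tree[path + "d"])
    return tree

tree = build_tree(np.ones((1, 8)), max_level=2)
tree["a" * 2] = tree["a" * 2] * 0
reconstruct(tree, max_level=2)
print(tree[""])  # all zeros: the constant signal lived entirely in "aa"
```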
- class jaxwt.packets.WaveletPacket2D(data: Array, wavelet: Union[str, Wavelet], mode: str = 'reflect', max_level: Optional[int] = None)
Bases: UserDict
A wavelet packet tree.
Create a 2D wavelet packet decomposition object.
Example code illustrating the use of this class is available at: https://github.com/v0lta/Jax-Wavelet-Toolbox/tree/packet-patch/examples/deepfake_analysis
- Parameters
data (jnp.ndarray) – The input data array of shape [batch_size, height, width].
wavelet (pywt.Wavelet or str) – The wavelet used for the decomposition.
mode (str) – The desired padding method. Options include “reflect”, “symmetric” and “zero”. Defaults to “reflect”.
max_level (int, optional) – Choose the desired decomposition level.
- reconstruct() → WaveletPacket2D
Recursively reconstruct the input starting from the leaf nodes.
Note
Only changes to leaf-node data affect the result, since all other nodes are replaced with a reconstruction from the leaves.
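In two dimensions every node splits into four children, so level n holds 4**n nodes. The NumPy sketch below builds such a tree for the Haar wavelet and reconstructs it bottom-up in the same leaf-driven way. The child keys "a", "h", "v", "d" (mirroring the cA/cH/cV/cD labels of wavedec2) are an assumption about naming, the helpers are hypothetical, and image sides must be divisible by 2**max_level because no padding is applied.

```python
import numpy as np

SQRT2 = np.sqrt(2.0)

def split(x, axis):
    """One Haar analysis step along a negative `axis`: (approximation, detail)."""
    n = x.shape[axis]
    even = np.take(x, np.arange(0, n, 2), axis=axis)
    odd = np.take(x, np.arange(1, n, 2), axis=axis)
    return (even + odd) / SQRT2, (even - odd) / SQRT2

def merge(a, d, axis):
    """Invert split(); np.stack puts the new axis right after a negative `axis`."""
    even, odd = (a + d) / SQRT2, (a - d) / SQRT2
    shape = list(even.shape)
    shape[axis] *= 2
    return np.stack([even, odd], axis=axis).reshape(shape)

def split2d(x):
    lo_h, hi_h = split(x, -2)  # along the height
    a, v = split(lo_h, -1)     # then along the width
    h, d = split(hi_h, -1)
    return {"a": a, "h": h, "v": v, "d": d}

def merge2d(c):
    return merge(merge(c["a"], c["v"], -1), merge(c["h"], c["d"], -1), -2)

def build_tree_2d(data, max_level):
    tree = {"": np.asarray(data, dtype=float)}
    for level in range(max_level):
        for path in [p for p in tree if len(p) == level]:
            for key, value in split2d(tree[path]).items():
                tree[path + key] = value
    return tree

def reconstruct_2d(tree, max_level):
    """Overwrite inner nodes bottom-up from their four children."""
    for level in reversed(range(max_level)):
        for path in [p for p in tree if len(p) == level]:
            tree[path] = merge2d({k: tree[path + k] for k in "ahvd"})
    return tree

data = np.random.default_rng(0).normal(size=(1, 16, 16))  # [batch_size, height, width]
tree = build_tree_2d(data, max_level=2)
leaves = [p for p in tree if len(p) == 2]
print(len(leaves), tree["ah"].shape)  # 16 (1, 4, 4)
```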
jaxwt.utils module
Various utility functions.
- jaxwt.utils.flatten_2d_coeff_lst(coeff_list_2d: List[Union[Array, Tuple[Array, Array, Array]]], flatten_arrays: bool = True) → List[Array]
Flatten a list of array tuples into a single list.
- Parameters
coeff_list_2d (list) – A pywt-style coefficient list.
flatten_arrays (bool) – If True, 2d arrays are flattened. Defaults to True.
- Returns
A single 1-d list with all original elements.
- Return type
list
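The flattening turns the nested pywt layout [cA, (cH, cV, cD), ...] into one flat list, which is convenient for concatenating all coefficients into a single feature vector. The sketch below is a hypothetical NumPy stand-in (`flatten_2d_coeff_lst_sketch`); it assumes that "flattened" means a full ravel of each array.

```python
import numpy as np

def flatten_2d_coeff_lst_sketch(coeff_list_2d, flatten_arrays=True):
    """Hypothetical sketch: [cA, (cH, cV, cD), ...] -> [cA, cH, cV, cD, ...]."""
    flat = []
    for entry in coeff_list_2d:
        # The approximation is a single array; detail levels are tuples of three.
        items = entry if isinstance(entry, (tuple, list)) else (entry,)
        for arr in items:
            flat.append(np.ravel(arr) if flatten_arrays else arr)
    return flat

coeffs = [np.zeros((2, 2)),
          (np.ones((2, 2)), 2 * np.ones((2, 2)), 3 * np.ones((2, 2))),
          (np.ones((4, 4)), np.ones((4, 4)), np.ones((4, 4)))]
flat = flatten_2d_coeff_lst_sketch(coeffs)
print(len(flat), np.concatenate(flat).size)  # 7 64
```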
jaxwt.version module
Version information for jwt.
Run with python -m jaxwt.version