jaxwt package
jaxwt.conv_fwt module
Convolution based fast wavelet transforms.
- jaxwt.conv_fwt.wavedec(data: Array, wavelet: Wavelet | str, mode: str = 'symmetric', level: int | None = None, axis: int = -1, precision: str = 'highest') → List[Array]
Compute the one-dimensional analysis wavelet transform along the last axis.
- Parameters:
data (jnp.ndarray) – Input data array, for example of shape [batch, time].
wavelet (Union[pywt.Wavelet, str]) – A wavelet name string or a wavelet object containing the wavelet filter arrays. Check pywt.wavelist() for a list of options.
mode (str) – The padding used to extend the input signal. Choose “reflect”, “symmetric” or “zero”. Defaults to “symmetric”.
level (int) – The maximum decomposition level. If None, as many levels as possible are used. Defaults to None.
axis (int) – Compute the transform over this axis instead of the last one. Defaults to -1.
precision (str) – For the desired precision, choose “fastest”, “high” or “highest”. Defaults to “highest”.
- Returns:
List containing the wavelet coefficients. The coefficients are in pywt order: [cA_n, cD_n, cD_n-1, …, cD2, cD1].
A denotes approximation and D detail coefficients.
- Return type:
list
- Raises:
ValueError – If the axis argument is not an integer.
Examples
>>> import pywt
>>> import jaxwt as jwt
>>> import jax.numpy as jnp
>>> # generate an input of even length.
>>> data = jnp.array([0., 1., 2., 3, 4, 5, 5, 4, 3, 2, 1, 0])
>>> jwt.wavedec(data, wavelet=pywt.Wavelet('haar'), level=2)
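The mode argument only affects how the signal is extended past its boundaries before filtering. To picture the three options, the sketch below uses jnp.pad; we assume that “symmetric” and “reflect” behave like the NumPy padding modes of the same name and that “zero” means constant zero padding. The pad width of 2 is arbitrary here; the amount jaxwt actually pads depends on the filter length and is not shown.

```python
import jax.numpy as jnp

signal = jnp.array([1.0, 2.0, 3.0])

# "symmetric": mirror the signal, repeating the edge samples.
symmetric = jnp.pad(signal, 2, mode="symmetric")  # [2, 1, 1, 2, 3, 3, 2]
# "reflect": mirror the signal about the edge samples without repeating them.
reflect = jnp.pad(signal, 2, mode="reflect")  # [3, 2, 1, 2, 3, 2, 1]
# "zero": extend the signal with zeros.
zero = jnp.pad(signal, 2, mode="constant", constant_values=0.0)  # [0, 0, 1, 2, 3, 0, 0]
```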
- jaxwt.conv_fwt.waverec(coeffs: List[Array], wavelet: Wavelet | str, axis: int = -1, precision: str = 'highest') → Array
Reconstruct the original signal in one dimension.
- Parameters:
coeffs (List[jnp.ndarray]) – Wavelet coefficients, typically produced by the wavedec function. List entries may have shape [batch_size, coefficients].
wavelet (Union[pywt.Wavelet, str]) – A string with a wavelet name or a wavelet object containing the wavelet filters used to evaluate the decomposition.
axis (int) – Transform this axis instead of the last one. Defaults to -1.
precision (str) – The desired precision, choose “fastest”, “high” or “highest”. Defaults to “highest”.
- Returns:
Reconstruction of the original data.
- Return type:
jnp.ndarray
- Raises:
ValueError – If the axis argument is not an integer.
Examples
>>> import pywt
>>> import jaxwt as jwt
>>> import jax.numpy as jnp
>>> # generate an input of even length.
>>> data = jnp.array([0., 1., 2., 3, 4, 5, 5, 4, 3, 2, 1, 0])
>>> transformed = jwt.wavedec(data, pywt.Wavelet('haar'))
>>> jwt.waverec(transformed, pywt.Wavelet('haar'))
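For the Haar wavelet used in the examples above, each analysis level reduces to sums and differences of neighbouring samples, provided the length stays even at every level (12, then 6, then 3 here). The following pure-JAX sketch uses our own helper names (haar_wavedec, haar_waverec), not jaxwt's convolution-based implementation, and assumes pywt's Haar sign convention, under which a sample pair (x0, x1) yields the detail coefficient (x0 - x1)/√2. It shows the [cA_n, cD_n, …, cD1] ordering and how reconstruction walks that list from the coarsest level down.

```python
import jax.numpy as jnp

SQRT2 = jnp.sqrt(2.0)


def haar_dwt(x):
    """One Haar analysis step along the last axis (even length assumed)."""
    even, odd = x[..., 0::2], x[..., 1::2]
    return (even + odd) / SQRT2, (even - odd) / SQRT2


def haar_idwt(approx, detail):
    """Invert haar_dwt by interleaving the recovered even and odd samples."""
    even = (approx + detail) / SQRT2
    odd = (approx - detail) / SQRT2
    return jnp.stack([even, odd], axis=-1).reshape(*approx.shape[:-1], -1)


def haar_wavedec(x, level):
    """Return coefficients in pywt order: [cA_n, cD_n, ..., cD1]."""
    approx, details = x, []
    for _ in range(level):
        approx, detail = haar_dwt(approx)
        details.append(detail)
    # details was filled finest-first, so reverse it to put cD_n first.
    return [approx] + details[::-1]


def haar_waverec(coeffs):
    """Undo haar_wavedec, starting from the coarsest level."""
    approx = coeffs[0]
    for detail in coeffs[1:]:
        approx = haar_idwt(approx, detail)
    return approx


data = jnp.array([0., 1., 2., 3, 4, 5, 5, 4, 3, 2, 1, 0])
coeffs = haar_wavedec(data, level=2)
rec = haar_waverec(coeffs)
```

For this input the sketch yields cA_2 ≈ [3, 9, 3] and cD_2 ≈ [-2, 0, 2], and the reconstruction matches data up to floating-point rounding.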
jaxwt.conv_fwt_2d module
Two dimensional convolution based fast wavelet transforms.
- jaxwt.conv_fwt_2d.wavedec2(data: Array, wavelet: Wavelet | str, mode: str = 'symmetric', level: int | None = None, axes: Tuple[int, int] = (-2, -1), precision: str = 'highest') → List[Array | Tuple[Array, Array, Array]]
Compute the two-dimensional wavelet analysis transform on the last two dimensions of the input data array.
- Parameters:
data (jnp.ndarray) – JAX array containing the data to be transformed. A possible input shape would be [batch size, height, width].
wavelet (Union[pywt.Wavelet, str]) – A wavelet object or wavelet string for the transformation. Check pywt.wavelist() for a list of options.
mode (str) – The desired padding mode. Choose “reflect”, “symmetric” or “zero”. Defaults to “symmetric”.
level (int) – The maximum decomposition level. If None, as many levels as possible are used. Defaults to None.
axes (Tuple[int, int]) – Compute the transform over these axes instead of the last two. Defaults to (-2, -1).
precision (str) – For the desired precision, choose “fastest”, “high” or “highest”. Defaults to “highest”.
- Returns:
The wavelet coefficients in a nested list; each coefficient array has shape [batch, height, width] at its scale. The coefficients are in pywt order, that is: [cAn, (cHn, cVn, cDn), …, (cH1, cV1, cD1)]. A denotes approximation, H horizontal, V vertical and D diagonal coefficients.
- Return type:
list
- Raises:
ValueError – If the axes tuple does not have two elements or contains a repetition.
Examples
>>> import pywt, scipy.datasets
>>> import jaxwt as jwt
>>> import jax.numpy as jnp
>>> face = jnp.transpose(scipy.datasets.face(), [2, 0, 1])
>>> face = face.astype(jnp.float32)
>>> jwt.wavedec2(face, "haar", level=2)
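A single level of the two-dimensional transform applies the one-dimensional filter pair along the height axis and then along the width axis. The pure-JAX Haar sketch below uses hypothetical helpers (haar_1d, haar_dwt2) that are not part of jaxwt; it requires even height and width and assumes pywt's naming convention, in which cH holds details along the height axis and cV details along the width axis.

```python
import jax.numpy as jnp

SQRT2 = jnp.sqrt(2.0)


def haar_1d(x, axis):
    """One Haar analysis step along `axis`; that axis must have even length."""
    x = jnp.moveaxis(x, axis, -1)
    even, odd = x[..., 0::2], x[..., 1::2]
    low = (even + odd) / SQRT2   # approximation (low-pass) half
    high = (even - odd) / SQRT2  # detail (high-pass) half
    return jnp.moveaxis(low, -1, axis), jnp.moveaxis(high, -1, axis)


def haar_dwt2(x):
    """Single-level 2D Haar step over the last two axes (height, width).

    Returns (cA, (cH, cV, cD)) under the assumed pywt naming:
    cH = high-pass along height, low-pass along width,
    cV = low-pass along height, high-pass along width.
    """
    low_h, high_h = haar_1d(x, axis=-2)  # filter along the height axis
    c_a, c_v = haar_1d(low_h, axis=-1)   # then along the width axis
    c_h, c_d = haar_1d(high_h, axis=-1)
    return c_a, (c_h, c_v, c_d)


# A batch of three 8 x 6 images made of alternating rows of zeros and ones.
rows = (jnp.arange(8) % 2).astype(jnp.float32)
images = jnp.broadcast_to(rows[None, :, None], (3, 8, 6))
c_a, (c_h, c_v, c_d) = haar_dwt2(images)
```

The test images consist of horizontal stripes, so only cA (all ones) and cH (all minus ones) are nonzero, and every output has shape (3, 4, 3).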
- jaxwt.conv_fwt_2d.waverec2(coeffs: List[Array | Tuple[Array, Array, Array]], wavelet: Wavelet | str, axes: Tuple[int, int] = (-2, -1), precision: str = 'highest') → Array
Compute a two-dimensional synthesis wavelet transform.
Use it to reconstruct the original input image from the wavelet coefficients.
- Parameters:
coeffs (list) – The input coefficients, typically the output of wavedec2.
wavelet (Union[pywt.Wavelet, str]) – The wavelet to use for the synthesis transform.
axes (Tuple[int, int]) – Compute the transform over these axes instead of the last two. Defaults to (-2, -1).
precision (str) – For the desired precision, choose “fastest”, “high” or “highest”. Defaults to “highest”.
- Raises:
ValueError – If the axes tuple does not have two elements or contains a repetition.
- Returns:
Reconstruction of the original input data array of shape [batch, height, width].
- Return type:
jnp.ndarray
Example
>>> import pywt, scipy.datasets
>>> import jaxwt as jwt
>>> import jax.numpy as jnp
>>> face = jnp.transpose(scipy.datasets.face(), [2, 0, 1])
>>> face = face.astype(jnp.float32)
>>> transformed = jwt.wavedec2(face, "haar")
>>> jwt.waverec2(transformed, "haar")
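waverec2 works through the coefficient list from the coarsest level down, undoing one level per detail tuple. The single-level Haar sketch below uses hypothetical helpers that are not jaxwt's code, and assumes even height and width plus pywt's cH/cV naming. It shows how one tuple is folded back in: the width step is inverted for both height branches before the height step itself is inverted.

```python
import jax.numpy as jnp

SQRT2 = jnp.sqrt(2.0)


def haar_1d(x, axis):
    """One Haar analysis step along `axis` (even length assumed)."""
    x = jnp.moveaxis(x, axis, -1)
    even, odd = x[..., 0::2], x[..., 1::2]
    return (jnp.moveaxis((even + odd) / SQRT2, -1, axis),
            jnp.moveaxis((even - odd) / SQRT2, -1, axis))


def inv_haar_1d(low, high, axis):
    """Invert haar_1d along `axis` by interleaving even and odd samples."""
    low, high = jnp.moveaxis(low, axis, -1), jnp.moveaxis(high, axis, -1)
    even, odd = (low + high) / SQRT2, (low - high) / SQRT2
    merged = jnp.stack([even, odd], axis=-1).reshape(*low.shape[:-1], -1)
    return jnp.moveaxis(merged, -1, axis)


def haar_dwt2(x):
    """Forward step: returns (cA, (cH, cV, cD)) over (height, width)."""
    low_h, high_h = haar_1d(x, axis=-2)
    c_a, c_v = haar_1d(low_h, axis=-1)
    c_h, c_d = haar_1d(high_h, axis=-1)
    return c_a, (c_h, c_v, c_d)


def haar_idwt2(c_a, details):
    """Inverse step for one (cH, cV, cD) tuple."""
    c_h, c_v, c_d = details
    low_h = inv_haar_1d(c_a, c_v, axis=-1)   # undo width step, low-height branch
    high_h = inv_haar_1d(c_h, c_d, axis=-1)  # undo width step, high-height branch
    return inv_haar_1d(low_h, high_h, axis=-2)  # undo the height step


image = jnp.arange(3 * 8 * 6, dtype=jnp.float32).reshape(3, 8, 6)
rec = haar_idwt2(*haar_dwt2(image))
```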
jaxwt.conv_fwt_3d module
Three-dimensional transformation support.
- jaxwt.conv_fwt_3d.wavedec3(data: Array, wavelet: Wavelet | str, mode: str = 'symmetric', level: int | None = None, axes: Tuple[int, int, int] = (-3, -2, -1), precision: str = 'highest') → List[Array | Dict[str, Array]]
Compute the three-dimensional wavelet analysis transform on the last three dimensions of the input data array.
- Parameters:
data (jnp.ndarray) – JAX array containing the data to be transformed. A possible input shape would be [batch size, channels, height, width].
wavelet (Union[pywt.Wavelet, str]) – A wavelet object or wavelet name string for the transformation.
mode (str) – The desired padding mode. Choose “reflect”, “symmetric” or “zero”. Defaults to “symmetric”.
level (int) – The maximum decomposition level. If None, as many levels as possible are used. Defaults to None.
axes (Tuple[int, int, int]) – Compute the transform over these axes instead of the last three. Defaults to (-3, -2, -1).
precision (str) – For the desired precision, choose “fastest”, “high” or “highest”. Defaults to “highest”.
- Returns:
A list with the approximation coefficients (low-pass filtered along all three axes), followed by dictionaries that use the filter order strings ("aad", "ada", "add", "daa", "dad", "dda", "ddd") as keys. Here a denotes the low-pass or approximation filter and d the high-pass or detail filter.
- Return type:
list
- Raises:
ValueError – If the input has fewer than three dimensions.
ValueError – If the axes tuple does not have three elements or contains a repetition.
Examples
>>> import pywt
>>> import jaxwt as jwt
>>> import jax
>>> data = jax.random.uniform(jax.random.PRNGKey(42),
...                           [3, 16, 16, 16])
>>> jwt.wavedec3(data, "haar", level=2)
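Each level of the three-dimensional transform splits every intermediate result once more per transformed axis, which produces the eight filter combinations: the approximation plus the seven keyed detail arrays. The sketch below builds such a dictionary with a hypothetical single-level Haar helper; it is not jaxwt's implementation, and it assumes (as in pywt's dwtn) that the i-th key character refers to the i-th transformed axis.

```python
import jax.numpy as jnp

SQRT2 = jnp.sqrt(2.0)


def haar_1d(x, axis):
    """One Haar analysis step along `axis` (even length assumed)."""
    x = jnp.moveaxis(x, axis, -1)
    even, odd = x[..., 0::2], x[..., 1::2]
    return (jnp.moveaxis((even + odd) / SQRT2, -1, axis),
            jnp.moveaxis((even - odd) / SQRT2, -1, axis))


def haar_dwt3(x):
    """Single-level 3D Haar step over the last three axes.

    The i-th key character names the filter applied along the i-th
    transformed axis: "a" for low-pass, "d" for high-pass.
    Returns [approximation, {detail key: array}].
    """
    nodes = {"": x}
    for axis in (-3, -2, -1):
        split = {}
        for key, value in nodes.items():
            low, high = haar_1d(value, axis)
            split[key + "a"] = low
            split[key + "d"] = high
        nodes = split
    approx = nodes.pop("aaa")  # low-pass along all three axes
    return [approx, nodes]


data = jnp.ones((3, 16, 16, 16))
approx, details = haar_dwt3(data)
```

For a constant input every detail array vanishes, while each low-pass step scales the constant by √2, so the approximation equals 2√2 everywhere.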
- jaxwt.conv_fwt_3d.waverec3(coeffs: List[Array | Dict[str, Array]], wavelet: Wavelet | str, axes: Tuple[int, int, int] = (-3, -2, -1), precision: str = 'highest') → Array
Compute a three-dimensional synthesis wavelet transform.
Use it to reconstruct the original input data from the wavelet coefficients.
- Parameters:
coeffs (list) – The input coefficients, typically the output of wavedec3.
wavelet (Union[pywt.Wavelet, str]) – The wavelet to use for the synthesis transform.
axes (Tuple[int, int, int]) – Transform these axes instead of the last three. Defaults to (-3, -2, -1).
precision (str) – For the desired precision, choose “fastest”, “high” or “highest”. Defaults to “highest”.
- Returns:
Reconstruction of the original input data array, for example of shape [batch, channels, height, width].
- Return type:
jnp.ndarray
- Raises:
ValueError – If the axes tuple does not have three elements or contains a repetition.
Example
>>> import pywt
>>> import jaxwt as jwt
>>> import jax
>>> data = jax.random.uniform(jax.random.PRNGKey(42),
...                           [3, 16, 16, 16])
>>> rec = jwt.waverec3(jwt.wavedec3(data, "haar", level=2), "haar")
>>> jax.numpy.allclose(data, rec)
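The ValueError conditions listed for the axes arguments of the 2D and 3D transforms can be checked before any filtering runs. The helper below is hypothetical (check_axes is not part of jaxwt's documented API); it only sketches one way to reject a wrong number of axes, out-of-range axes, and repetitions such as -1 and 3 on a four-dimensional input, by normalising negative indices first.

```python
def check_axes(axes, ndim, expected):
    """Hypothetical helper: validate an `axes` argument and return it as
    non-negative axis indices.

    Raises ValueError for a wrong number of axes, an out-of-range or
    non-integer axis, or a repetition such as (-1, 3) when ndim == 4.
    """
    axes = tuple(axes)
    if len(axes) != expected:
        raise ValueError(f"Expected {expected} axes, got {len(axes)}.")
    for axis in axes:
        if not isinstance(axis, int) or not -ndim <= axis < ndim:
            raise ValueError(f"Invalid axis {axis!r} for a {ndim}-dimensional input.")
    # Map negative indices to their positive counterparts before comparing.
    normalised = tuple(axis % ndim for axis in axes)
    if len(set(normalised)) != len(normalised):
        raise ValueError(f"The axes {axes} contain a repetition.")
    return normalised


def raises_value_error(*args, **kwargs):
    """Return True if check_axes rejects the given arguments."""
    try:
        check_axes(*args, **kwargs)
    except ValueError:
        return True
    return False
```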
jaxwt.continuous_transform
JAX-compatible continuous wavelet transform (CWT) code.
- jaxwt.continuous_transform.cwt(data: Array, scales: ndarray | Array, wavelet: ContinuousWavelet | str, sampling_period: float = 1.0) → Tuple[Array, Array]
Compute the one-dimensional continuous wavelet transform.
This function is a jax port of pywt.cwt as found at: https://github.com/PyWavelets/pywt/blob/master/pywt/_cwt.py
- Parameters:
data (jnp.ndarray) – The input tensor of shape [batch_size, time].
scales (np.ndarray or jnp.ndarray) – The wavelet scales to use. One can use f = pywt.scale2frequency(wavelet, scale)/sampling_period to determine the physical frequency f that corresponds to a given scale. Here, f is in hertz when the sampling_period is given in seconds.
wavelet (ContinuousWavelet or str) – The continuous wavelet to work with.
sampling_period (float) – Sampling period for the frequencies output (optional). The values computed for coefs are independent of the choice of sampling_period (i.e. scales is not scaled by the sampling period).
- Raises:
ValueError – If a scale is too small for the input signal.
- Returns:
A tuple with the transformation matrix and the frequencies, in this order.
- Return type:
Tuple[jnp.ndarray, jnp.ndarray]
Example
>>> import jaxwt as jwt
>>> import jax.numpy as jnp
>>> import scipy.signal as signal
>>> t = jnp.linspace(-2, 2, 800, endpoint=False)
>>> sig = signal.chirp(t, f0=1, f1=12, t1=2, method="linear")
>>> widths = jnp.arange(1, 31)
>>> cwtmatr, freqs = jwt.cwt(
...     jnp.array(sig), widths, "mexh",
...     sampling_period=(4 / 800) * jnp.pi
... )
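Conceptually, each row of the CWT output correlates the signal with the wavelet stretched to one scale. The sketch below is a naive direct-convolution version with the Mexican hat wavelet; it is not a port of pywt.cwt (which integrates the wavelet before convolving and then differentiates), so magnitudes and edge behaviour will differ, and it omits the frequency output. We assume the "mexh" formula 2/(√3·π^(1/4))·(1 − t²)·exp(−t²/2), and the scales must be Python integers here so that the kernel lengths are static.

```python
import jax.numpy as jnp


def mexican_hat(t):
    """Mexican hat wavelet; we assume this matches pywt's "mexh" formula."""
    return 2.0 / (jnp.sqrt(3.0) * jnp.pi**0.25) * (1.0 - t**2) * jnp.exp(-(t**2) / 2.0)


def naive_cwt(signal, scales):
    """Naive CWT of a 1D signal: one output row per (integer) scale.

    Each kernel samples the wavelet on [-5 * scale, 5 * scale] and must be
    shorter than the signal so that mode="same" keeps the signal length.
    """
    rows = []
    for scale in scales:
        support = jnp.arange(-5 * scale, 5 * scale + 1, dtype=jnp.float32)
        kernel = mexican_hat(support / scale) / jnp.sqrt(scale)
        # The Mexican hat is symmetric, so convolution equals correlation here.
        rows.append(jnp.convolve(signal, kernel, mode="same"))
    return jnp.stack(rows)


t = jnp.linspace(0.0, 1.0, 256)
sig = jnp.sin(2 * jnp.pi * 8 * t)
scales = [1, 2, 4, 8]
coefs = naive_cwt(sig, scales)  # shape (4, 256)
```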
jaxwt.packets module
Compute wavelet packets using jwt.
- class jaxwt.packets.WaveletPacket(data: Array, wavelet: Wavelet | str, mode: str = 'symmetric', max_level: int | None = None)
Bases: UserDict
A wavelet packet tree.
Create a wavelet packet decomposition object.
- Parameters:
data (jnp.ndarray) – The input data array of shape [batch_size, time].
wavelet (Union[pywt.Wavelet, str]) – The wavelet used for the decomposition.
mode (str) – The desired padding method. Choose, for example, “reflect”, “symmetric” or “zero”. Defaults to “symmetric”.
max_level (int, optional) – The desired maximum decomposition level. Defaults to None.
Example
>>> import pywt
>>> import jax.numpy as jnp
>>> from jaxwt import WaveletPacket
>>> import scipy.signal as signal
>>> wavelet = pywt.Wavelet("db4")
>>> t = jnp.linspace(0, 10, 5001)
>>> w = signal.chirp(t, f0=0.00001,
...                  f1=20, t1=10, method="linear")
>>> wp = WaveletPacket(data=w, wavelet=wavelet)
>>> nodes = wp.get_level(7)
>>> jnp_lst = []
>>> for node in nodes:
...     jnp_lst.append(wp[node])
>>> viz = jnp.concatenate(jnp_lst)
- get_level(level: int) → List[str]
Return the node names for a given level in Gray-code order.
- Parameters:
level (int) – The required depth of the tree.
- Returns:
A list with the node names.
- Return type:
list
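We assume get_level orders node names by the binary-reflected Gray code, with a for the low-pass and d for the high-pass branch, so that neighbouring nodes differ in exactly one filter choice. The helper below is a hypothetical sketch of that ordering, not jaxwt's own code.

```python
def graycode_order(level, low="a", high="d"):
    """Hypothetical helper: packet node names of one level in Gray-code order.

    Built by reflection: prefix the previous order with `low`, then prefix
    the reversed previous order with `high`.
    """
    order = [low, high]
    for _ in range(level - 1):
        order = [low + path for path in order] + [high + path for path in reversed(order)]
    return order


print(graycode_order(2))  # ['aa', 'ad', 'dd', 'da']
```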
- reconstruct() → WaveletPacket
Recursively reconstruct the input starting from the leaf nodes.
Reconstruction replaces the input data originally assigned to this object.
Note
Only changes to leaf node data impact the results, since the data in all other nodes is replaced with a reconstruction from the leaves.
Example
>>> import jaxwt as jwt
>>> import jax
>>> key = jax.random.PRNGKey(0)
>>> input_data = jax.random.normal(key, (1, 24))
>>> jwp = jwt.WaveletPacket(input_data, "haar", max_level=2)
>>> jwp["a" * 2] *= 0
>>> jwp.reconstruct()
>>> print(jwp[""])
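The behaviour described in the note can be mimicked with a plain dictionary keyed by node paths, similar in spirit to the UserDict base class. The sketch below is a hypothetical Haar-only stand-in for WaveletPacket (all helper names are ours, and even lengths are assumed at every level): reconstruction recomputes every inner node from its children, deepest level first, so an edit to an inner node is overwritten while an edit to a leaf propagates to the root.

```python
import jax.numpy as jnp

SQRT2 = jnp.sqrt(2.0)


def haar_dec(x):
    """Split the last axis into Haar approximation and detail halves."""
    even, odd = x[..., 0::2], x[..., 1::2]
    return (even + odd) / SQRT2, (even - odd) / SQRT2


def haar_rec(approx, detail):
    """Invert haar_dec by interleaving the recovered even and odd samples."""
    even, odd = (approx + detail) / SQRT2, (approx - detail) / SQRT2
    return jnp.stack([even, odd], axis=-1).reshape(*approx.shape[:-1], -1)


def build_packet_tree(data, max_level):
    """Full packet tree: the root is "", each child appends "a" or "d"."""
    tree = {"": data}
    for level in range(max_level):
        for path in [p for p in tree if len(p) == level]:
            tree[path + "a"], tree[path + "d"] = haar_dec(tree[path])
    return tree


def reconstruct_tree(tree, max_level):
    """Rebuild every inner node from its children, deepest level first."""
    tree = dict(tree)
    for level in range(max_level - 1, -1, -1):
        for path in [p for p in tree if len(p) == level]:
            tree[path] = haar_rec(tree[path + "a"], tree[path + "d"])
    return tree


data = jnp.arange(24.0).reshape(1, 24)
tree = build_packet_tree(data, max_level=2)
rebuilt = reconstruct_tree(tree, max_level=2)[""]
# Editing an inner node has no effect: it is recomputed from its children.
inner_edit = reconstruct_tree({**tree, "a": tree["a"] * 0}, max_level=2)[""]
# Editing a leaf does change the reconstruction.
leaf_edit = reconstruct_tree({**tree, "aa": tree["aa"] * 0}, max_level=2)[""]
```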
- class jaxwt.packets.WaveletPacket2D(data: Array, wavelet: str | Wavelet, mode: str = 'symmetric', max_level: int | None = None)
Bases: UserDict
A wavelet packet tree.
Create a 2D-wavelet packet decomposition object.
Example code illustrating the use of this class is available at: https://github.com/v0lta/Jax-Wavelet-Toolbox/tree/master/examples/deepfake_analysis
- Parameters:
data (jnp.ndarray) – The input data array of shape [batch_size, height, width].
wavelet (pywt.Wavelet or str) – The wavelet used for the decomposition.
mode (str) – The desired padding method. Choose, for example, “reflect”, “symmetric” or “zero”. Defaults to “symmetric”.
max_level (int, optional) – The desired maximum decomposition level. Defaults to None.
- reconstruct() → WaveletPacket2D
Recursively reconstruct the input starting from the leaf nodes.
Note
Only changes to leaf node data impact the results, since the data in all other nodes is replaced with a reconstruction from the leaves.
jaxwt.stationary_transform
Code for stationary wavelet transforms.
- jaxwt.stationary_transform.iswt(coeffs: List[Array], wavelet: Wavelet | str, axis: int = -1, precision: str | None = 'highest') → Array
Compute an inverse stationary wavelet transform.
- Parameters:
coeffs (List[jnp.ndarray]) – The coefficients as computed by the analysis code.
wavelet (Union[pywt.Wavelet, str]) – The wavelet used by the transform.
axis (int) – Transform this axis instead of the last one. Defaults to -1.
precision (Optional[str]) – Precision value for the underlying lax convolution code. Defaults to “highest”.
- Raises:
ValueError – If the axis argument is not an integer.
- Returns:
The reconstruction of the original signal.
- Return type:
jnp.ndarray
Example
>>> import jax, jaxwt
>>> import jax.numpy as jnp
>>> signal = jax.random.randint(
...     jax.random.PRNGKey(42), [1, 10], 0, 9).astype(jnp.float32)
>>> jaxwt.iswt(jaxwt.swt(signal, "haar", level=2), "haar")
- jaxwt.stationary_transform.swt(data: Array, wavelet: Wavelet | str, level: int | None = None, axis: int = -1, precision: str = 'highest') → List[Array]
Compute a multilevel 1d stationary wavelet transform.
- Parameters:
data (jnp.ndarray) – The input data of shape [batch_size, time]. This function assumes a trailing input dimension with a length divisible by two.
wavelet (Union[Wavelet, str]) – The wavelet to use.
level (Optional[int], optional) – The number of levels to compute. Defaults to None.
axis (int) – Compute the transform over this axis instead of the last one. Defaults to -1.
precision (str) – For the desired precision, choose “fastest”, “high” or “highest”. Defaults to “highest”.
- Returns:
A list containing the wavelet coefficients. The coefficients are in pywt order: [cA_n, cD_n, cD_n-1, …, cD2, cD1]. A denotes approximation and D detail coefficients. The ordering is identical to that of the wavedec function. The output is equivalent to pywt.swt with trim_approx=True.
- Return type:
List[jnp.ndarray]
- Raises:
ValueError – If the axis argument is not an integer.
Example
>>> import jax, jaxwt
>>> import jax.numpy as jnp
>>> signal = jax.random.randint(
...     jax.random.PRNGKey(42), [1, 10], 0, 9).astype(jnp.float32)
>>> jaxwt.swt(signal, "haar", level=2)
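Unlike wavedec, the stationary transform never downsamples; instead, the filter taps move further apart at each level (the "à trous" idea), so every coefficient array keeps the input length. The Haar sketch below uses hypothetical helpers (haar_swt, haar_iswt) with periodic boundaries via jnp.roll. It is not jaxwt's implementation, and pywt aligns its shifted filters differently, so individual coefficients may appear circularly shifted relative to library output; the list order [cA_n, cD_n, …, cD1] and the exact invertibility are what it illustrates.

```python
import jax.numpy as jnp

SQRT2 = jnp.sqrt(2.0)


def haar_swt(x, level):
    """Undecimated Haar transform along the last axis, periodic boundaries.

    At level j + 1 the two filter taps are 2**j samples apart and nothing is
    downsampled, so every output keeps the input length.
    Returns [cA_n, cD_n, ..., cD1].
    """
    approx, details = x, []
    for j in range(level):
        shifted = jnp.roll(approx, -(2**j), axis=-1)  # shifted[n] = approx[n + 2**j]
        approx, detail = (approx + shifted) / SQRT2, (approx - shifted) / SQRT2
        details.append(detail)
    return [approx] + details[::-1]


def haar_iswt(coeffs):
    """Invert haar_swt: with a = (x + s) / sqrt(2) and d = (x - s) / sqrt(2),
    every level gives back x = (a + d) / sqrt(2)."""
    approx = coeffs[0]
    for detail in coeffs[1:]:
        approx = (approx + detail) / SQRT2
    return approx


x = jnp.arange(16.0).reshape(1, 16)
coeffs = haar_swt(x, level=3)
rec = haar_iswt(coeffs)
```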
jaxwt.version module
Version information for jwt. Run with python -m jaxwt.version.