Source code for jaxwt.packets

"""Compute wavelet packets using jwt."""

#
# Created on Fri Jun 19 2020
# Copyright (c) 2020 Moritz Wolter
#
import collections
from itertools import product
from typing import TYPE_CHECKING, List, Optional, Union

# import jax
import jax.numpy as jnp
import pywt

from .conv_fwt import wavedec, waverec
from .conv_fwt_2d import wavedec2, waverec2
from .utils import _as_wavelet

# Base class alias for the packet containers below. The subscripted form is
# only evaluated by static type checkers; at runtime the plain
# ``collections.UserDict`` class is used.
if not TYPE_CHECKING:
    BaseDict = collections.UserDict
else:
    BaseDict = collections.UserDict[str, jnp.ndarray]

class WaveletPacket(BaseDict):
    """A wavelet packet tree.

    The tree is stored as a dictionary. Keys are path strings built from the
    characters "a" and "d"; the empty string "" is the root and holds the
    input. Each "a" child holds the first (low-pass, ``res_lo``) output of a
    single-level ``wavedec`` of its parent and each "d" child the second
    (high-pass, ``res_hi``) output.
    """

    def __init__(
        self,
        data: jnp.ndarray,
        wavelet: Union[pywt.Wavelet, str],
        mode: str = "symmetric",
        max_level: Optional[int] = None,
    ):
        """Create a wavelet packet decomposition object.

        Args:
            data (jnp.ndarray): The input data array of shape
                [batch_size, time]. A one-dimensional array of shape [time]
                is also accepted; a leading batch axis of size one is added.
            wavelet (Union[pywt.Wavelet, str]): The wavelet used for the
                decomposition.
            mode (str): The desired padding method. Choose i.e.
                "reflect", "symmetric" or "zero". Defaults to "symmetric".
            max_level (int, optional): The depth of the tree. If None,
                ``pywt.dwt_max_level`` is evaluated on the length of the
                last axis and the decomposition filter length.
                Defaults to None.

        Raises:
            ValueError: If the input is neither one- nor two-dimensional.

        Example:
            >>> import pywt
            >>> import jax.numpy as jnp
            >>> from jaxwt import WaveletPacket
            >>> import scipy.signal as signal
            >>> wavelet = pywt.Wavelet("db4")
            >>> t = jnp.linspace(0, 10, 5001)
            >>> w = signal.chirp(t, f0=0.00001,
            ...                  f1=20, t1=10, method="linear")
            >>> wp = WaveletPacket(data=w, wavelet=wavelet)
            >>> nodes = wp.get_level(7)
            >>> jnp_lst = []
            >>> for node in nodes:
            ...     jnp_lst.append(wp[node])
            >>> viz = jnp.concatenate(jnp_lst)
        """
        # Let UserDict create the backing ``self.data`` dictionary.
        super().__init__()

        n_dims = len(data.shape)
        if n_dims == 1:
            # Add the batch axis the rest of the class works with.
            self.input_data = jnp.expand_dims(data, 0)
        elif n_dims == 2:
            self.input_data = data
        else:
            # Fail early; otherwise ``input_data`` would never be assigned
            # and the error would surface as an unrelated AttributeError.
            raise ValueError(
                "WaveletPacket expects input of shape [time] or "
                f"[batch_size, time], got shape {tuple(data.shape)}."
            )

        self.wavelet = _as_wavelet(wavelet)
        self.mode = mode
        if max_level is None:
            self.max_level = pywt.dwt_max_level(
                self.input_data.shape[-1], self.wavelet.dec_len
            )
        else:
            self.max_level = max_level
        self._recursive_dwt(self.input_data, level=0, path="")

    def get_level(self, level: int) -> List[str]:
        """Return the graycodes for a given level.

        Args:
            level (int): The required depth of the tree.

        Returns:
            list: A list with the node names.

        Raises:
            ValueError: If the level is negative.
        """
        return self._get_graycode_order(level)

    def _get_graycode_order(self, level: int, x: str = "a", y: str = "d") -> List[str]:
        """Build the node names of one level in graycode order.

        Args:
            level (int): The depth of the tree level.
            x (str): The character used for the first branch.
                Defaults to "a".
            y (str): The character used for the second branch.
                Defaults to "d".

        Returns:
            list: The node names; [""] for level zero.

        Raises:
            ValueError: If the level is negative.
        """
        if level < 0:
            raise ValueError(f"The level must be non-negative, got {level}.")
        if level == 0:
            return [""]
        graycode_order = [x, y]
        for _ in range(level - 1):
            # Prefix the current order with x, and its mirror image with y.
            graycode_order = [x + path for path in graycode_order] + [
                y + path for path in reversed(graycode_order)
            ]
        return graycode_order

    def _recursive_dwt(
        self,
        data: jnp.ndarray,
        level: int,
        path: str,
    ) -> None:
        """Store a node and decompose it until ``max_level`` is reached.

        Args:
            data (jnp.ndarray): The coefficients belonging to this node.
            level (int): The depth of this node; the root is level zero.
            path (str): The key under which this node is stored.
        """
        self.data[path] = data
        if level < self.max_level:
            res_lo, res_hi = wavedec(
                data=data, wavelet=self.wavelet, level=1, mode=self.mode
            )
            self._recursive_dwt(res_lo, level + 1, path + "a")
            self._recursive_dwt(res_hi, level + 1, path + "d")

    def reconstruct(self) -> "WaveletPacket":
        """Recursively reconstruct the input starting from the leaf nodes.

        Reconstruction replaces the input-data originally assigned to this
        object.

        Note:
            Only changes to leaf node data impacts the results,
            since changes in all other nodes will be replaced with
            a reconstruction from the leafs.

        Returns:
            WaveletPacket: This object, with every non-leaf node overwritten.

        Example:
            >>> import jaxwt as jwt
            >>> import jax
            >>> key = jax.random.PRNGKey(0)
            >>> input_data = jax.random.normal(key, (1, 24))
            >>> jwp = jwt.WaveletPacket(input_data, "haar", max_level=2)
            >>> jwp["a" * 2] *= 0
            >>> jwp.reconstruct()
            >>> print(jwp[""])
        """
        # Only relevant if a caller reset max_level after construction.
        if self.max_level is None:
            self.max_level = pywt.dwt_max_level(
                self[""].shape[-1], self.wavelet.dec_len
            )

        # Walk from the parents of the leaves up to the root.
        # NOTE(review): no length correction is applied to the result of
        # waverec; confirm it matches the parent's original length when a
        # level had an odd number of samples.
        for level in reversed(range(self.max_level)):
            for node in self.get_level(level):
                data_a = self[node + "a"]
                data_b = self[node + "d"]
                self[node] = waverec([data_a, data_b], self.wavelet)
        return self
class WaveletPacket2D(BaseDict):
    """A two-dimensional wavelet packet tree.

    The tree is stored as a dictionary. Keys are path strings built from the
    characters "a", "h", "v" and "d"; the empty string "" is the root and
    holds the input. The children of a node hold the four outputs of a
    single-level ``wavedec2`` of that node, in the order it returns them
    (presumably approximation, horizontal, vertical and diagonal
    coefficients).
    """

    def __init__(
        self,
        data: jnp.ndarray,
        wavelet: Union[str, pywt.Wavelet],
        mode: str = "symmetric",
        max_level: Optional[int] = None,
    ):
        """Create a 2D-wavelet packet decomposition object.

        Example code illustrating the use of this class is available at:
        https://github.com/v0lta/Jax-Wavelet-Toolbox/tree/master/examples/deepfake_analysis

        Args:
            data (jnp.ndarray): The input data array
                of shape [batch_size, height, width].
            wavelet (pywt.Wavelet or str): The wavelet used for the
                decomposition.
            mode (str): The desired padding method. Choose i.e.
                "reflect", "symmetric" or "zero". Defaults to "symmetric".
            max_level (int, optional): Choose the desired decomposition
                level. If None, ``pywt.dwt_max_level`` is evaluated on the
                smaller of the last two axes and the decomposition filter
                length. Defaults to None.

        Raises:
            ValueError: If the input has fewer than two dimensions.
        """
        # Let UserDict create the backing ``self.data`` dictionary.
        super().__init__()

        if len(data.shape) < 2:
            # Without two trailing axes the level computation below would
            # silently look at a single axis only.
            raise ValueError(
                "WaveletPacket2D expects input of shape "
                f"[batch_size, height, width], got shape {tuple(data.shape)}."
            )

        self.input_data = data
        self.wavelet: pywt.Wavelet = _as_wavelet(wavelet)
        self.mode = mode
        if max_level is None:
            self.max_level = pywt.dwt_max_level(
                min(self.input_data.shape[-2:]), self.wavelet.dec_len
            )
        else:
            self.max_level = max_level
        self._recursive_dwt2d(self.input_data, level=0, path="")

    def _recursive_dwt2d(
        self,
        data: jnp.ndarray,
        level: int,
        path: str,
    ) -> None:
        """Store a node and decompose it until ``max_level`` is reached.

        Args:
            data (jnp.ndarray): The coefficients belonging to this node.
            level (int): The depth of this node; the root is level zero.
            path (str): The key under which this node is stored.
        """
        self.data[path] = data
        if level < self.max_level:
            result_a, (result_h, result_v, result_d) = wavedec2(
                data=data, wavelet=self.wavelet, level=1, mode=self.mode
            )
            # assert for type checking
            assert not isinstance(result_a, tuple)
            self._recursive_dwt2d(result_a, level + 1, path + "a")
            self._recursive_dwt2d(result_h, level + 1, path + "h")
            self._recursive_dwt2d(result_v, level + 1, path + "v")
            self._recursive_dwt2d(result_d, level + 1, path + "d")

    def _get_natural_order(self, level: int) -> List[str]:
        """Get the natural ordering for a given decomposition level.

        Args:
            level (int): The decomposition level.

        Returns:
            list: A list with the filter order strings.
        """
        return ["".join(p) for p in product(["a", "h", "v", "d"], repeat=level)]

    def reconstruct(self) -> "WaveletPacket2D":
        """Recursively reconstruct the input starting from the leaf nodes.

        Note:
            Only changes to leaf node data impacts the results,
            since changes in all other nodes will be replaced with
            a reconstruction from the leafs.

        Returns:
            WaveletPacket2D: This object, with every non-leaf node
            overwritten.
        """
        # Only relevant if a caller reset max_level after construction.
        if self.max_level is None:
            self.max_level = pywt.dwt_max_level(
                min(self[""].shape[-2:]), self.wavelet.dec_len
            )

        # Walk from the parents of the leaves up to the root.
        for level in reversed(range(self.max_level)):
            for node in self._get_natural_order(level):
                data_a = self[node + "a"]
                data_h = self[node + "h"]
                data_v = self[node + "v"]
                data_d = self[node + "d"]
                self[node] = waverec2(
                    [data_a, (data_h, data_v, data_d)], self.wavelet
                )
        return self