Source code for jaxwt.packets

"""Compute wavelet packets using jwt."""

#
# Created on Fri Jun 19 2020
# Copyright (c) 2020 Moritz Wolter
#
import collections
from typing import TYPE_CHECKING, List, Optional, Union

import jax
import jax.numpy as jnp
import pywt

from .conv_fwt import _fwt_pad, _get_filter_arrays
from .conv_fwt_2d import wavedec2
from .utils import Wavelet, _as_wavelet

# Base class shared by the packet containers in this module.
# Static type checkers see the parametrized mapping type (str keys mapping to
# jnp.ndarray values); at runtime the plain, unsubscripted class is used.
# NOTE(review): presumably split this way because subscripting UserDict is
# not possible at runtime on every supported Python version -- confirm.
if TYPE_CHECKING:
    BaseDict = collections.UserDict[str, jnp.ndarray]
else:
    BaseDict = collections.UserDict

class WaveletPacket(BaseDict):
    """A one-dimensional wavelet packet tree.

    Nodes are stored under string keys. Every analysis step appends "a"
    for the first (low-pass) filter output and "d" for the second
    (high-pass) filter output, so "ad" is the "d" child of node "a".
    The empty string "" is the root node and holds the input itself.
    """

    def __init__(
        self,
        data: jnp.ndarray,
        wavelet: Union[Wavelet, str],
        mode: str = "reflect",
        max_level: Optional[int] = None,
    ):
        """Create a wavelet packet decomposition object.

        Args:
            data (jnp.array): The input data array of shape [time]
                or [batch_size, time].
            wavelet (Wavelet or str): The wavelet used for the
                decomposition. It is passed through ``_as_wavelet``
                before use.
            mode (str): The desired padding method. Choose i.e.
                "reflect", "symmetric" or "zero". Defaults to "reflect".
            max_level (int, optional): The depth of the tree. If None,
                ``pywt.dwt_max_level`` is evaluated for the length of the
                last axis and the filter length. Defaults to None.

        Raises:
            ValueError: If data does not have one or two dimensions.
        """
        n_dims = len(data.shape)
        if n_dims == 1:
            # [time] -> [1, 1, time]: add a batch and a channel axis.
            input_data = jnp.expand_dims(jnp.expand_dims(data, 0), 0)
        elif n_dims == 2:
            # [batch, time] -> [batch, 1, time]: add a channel axis.
            input_data = jnp.expand_dims(data, 1)
        else:
            # Previously input_data was silently left unset for other ranks,
            # which surfaced later as an unrelated AttributeError.
            raise ValueError(
                "WaveletPacket expects data of shape [time] or "
                f"[batch_size, time], but got shape {tuple(data.shape)}."
            )

        # UserDict.__init__ creates the empty ``self.data`` node storage.
        super().__init__()
        self.input_data = input_data
        self.wavelet = _as_wavelet(wavelet)
        self.mode = mode

        dec_lo, dec_hi, _, _ = _get_filter_arrays(self.wavelet, flip=True)
        filt_len = dec_lo.shape[-1]
        # Low- and high-pass filters become the two output channels of a
        # single convolution.
        filt = jnp.stack([dec_lo, dec_hi], 0)
        if max_level is None:
            max_level = pywt.dwt_max_level(data.shape[-1], filt_len)
        self.max_level = max_level

        self._recursive_dwt(
            self.input_data, filt, level=0, max_level=max_level, path=""
        )

    def get_level(self, level: int) -> List[str]:
        """Return the graycodes for a given level.

        Args:
            level (int): The required depth of the tree. Values below
                two all yield the level-one result ["a", "d"].

        Returns:
            list: A list with the node names.
        """
        return self._get_graycode_order(level)

    def _get_graycode_order(
        self, level: int, x: str = "a", y: str = "d"
    ) -> List[str]:
        """Build the node names of a level in reflected (Gray-code) order.

        Args:
            level (int): The depth of the tree level.
            x (str): The symbol prepended to the first half. Defaults to "a".
            y (str): The symbol prepended to the second half. Defaults to "d".

        Returns:
            list: The node names; 2 ** level entries for level >= 1.
        """
        graycode_order = [x, y]
        for _ in range(level - 1):
            # Prefix the current list with x and its mirror image with y.
            graycode_order = [x + path for path in graycode_order] + [
                y + path for path in graycode_order[::-1]
            ]
        return graycode_order

    def _recursive_dwt(
        self,
        data: jnp.ndarray,
        filt: jnp.ndarray,
        level: int,
        max_level: int,
        path: str,
    ) -> None:
        """Store the node at ``path`` and expand it until ``max_level``.

        Args:
            data (jnp.ndarray): The coefficients of the current node, laid
                out as [batch, channel, time] ("NCH") with one channel.
            filt (jnp.ndarray): The stacked analysis filters used as the
                "OIH" convolution kernel.
            level (int): The depth of the current node.
            max_level (int): The depth at which the recursion stops.
            path (str): The key of the current node.
        """
        # Every node, inner or leaf, is stored without its channel axis.
        self.data[path] = jnp.squeeze(data, 1)
        if level >= max_level:
            return

        padded = _fwt_pad(data, filt_len=filt.shape[-1], mode=self.mode)
        res = jax.lax.conv_general_dilated(
            lhs=padded,  # lhs = NCH image tensor
            rhs=filt,  # rhs = OIH conv kernel tensor
            padding="VALID",
            window_strides=[2],  # stride two halves the time resolution
            dimension_numbers=("NCH", "OIH", "NCH"),
        )
        # Output channel 0 came from dec_lo, channel 1 from dec_hi.
        res_lo, res_hi = jnp.split(res, 2, 1)
        self._recursive_dwt(res_lo, filt, level + 1, max_level, path + "a")
        self._recursive_dwt(res_hi, filt, level + 1, max_level, path + "d")
class WaveletPacket2D(BaseDict):
    """A two-dimensional wavelet packet tree.

    Nodes are stored under string keys. Every analysis step appends one
    of "a", "h", "v" or "d" to the key of the parent node, one letter per
    sub-band returned by a single-level ``wavedec2`` call. The empty
    string "" is the root node and holds the input itself.
    """

    def __init__(
        self,
        data: jnp.ndarray,
        wavelet: Union[str, pywt.Wavelet],
        mode: str = "reflect",
        max_level: Optional[int] = None,
    ):
        """Create a 2D-wavelet packet decomposition object.

        Args:
            data (jnp.array): The input data array
                of shape [batch_size, height, width].
            wavelet (Wavelet or str): The wavelet used for the
                decomposition. It is passed through ``_as_wavelet``
                before use.
            mode (str): The desired padding method. Choose i.e.
                "reflect", "symmetric" or "zero". Defaults to "reflect".
            max_level (int, optional): The depth of the tree. If None,
                ``pywt.dwt_max_level`` is evaluated for the smaller of the
                last two axes and the decomposition filter length.
                Defaults to None.
        """
        # UserDict.__init__ creates the empty ``self.data`` node storage.
        super().__init__()
        self.input_data = data
        self.wavelet: pywt.Wavelet = _as_wavelet(wavelet)
        self.mode = mode
        if max_level is None:
            self.max_level = pywt.dwt_max_level(
                min(self.input_data.shape[-2:]), self.wavelet.dec_len
            )
        else:
            self.max_level = max_level
        self._recursive_dwt2d(self.input_data, level=0, path="")

    def get_level(self, level: int) -> List[str]:
        """Return the graycodes for a given level.

        Only the symbols "a" and "d" are used to build the names, so
        nodes whose key contains "h" or "v" are not part of the result.
        Use ``get_natural_order`` to list every node of a level.

        Args:
            level (int): The required depth of the tree. Values below
                two all yield the level-one result ["a", "d"].

        Returns:
            list: A list with the node names.
        """
        return self._get_graycode_order(level)

    def get_natural_order(self, level: int) -> List[str]:
        """Return the keys of all nodes of a level.

        Args:
            level (int): The required depth of the tree.

        Returns:
            list: All 4 ** level keys over "a", "h", "v" and "d", with
                the last letter varying fastest. Level zero yields [""].
        """
        paths = [""]
        for _ in range(level):
            paths = [path + code for path in paths for code in "ahvd"]
        return paths

    def _get_graycode_order(
        self, level: int, x: str = "a", y: str = "d"
    ) -> List[str]:
        """Build node names of a level in reflected (Gray-code) order.

        Args:
            level (int): The depth of the tree level.
            x (str): The symbol prepended to the first half. Defaults to "a".
            y (str): The symbol prepended to the second half. Defaults to "d".

        Returns:
            list: The node names; 2 ** level entries for level >= 1.
        """
        graycode_order = [x, y]
        for _ in range(level - 1):
            # Prefix the current list with x and its mirror image with y.
            graycode_order = [x + path for path in graycode_order] + [
                y + path for path in graycode_order[::-1]
            ]
        return graycode_order

    def _recursive_dwt2d(
        self,
        data: jnp.ndarray,
        level: int,
        path: str,
    ) -> None:
        """Store the node at ``path`` and expand it until ``self.max_level``.

        Args:
            data (jnp.ndarray): The coefficients of the current node.
            level (int): The depth of the current node.
            path (str): The key of the current node.
        """
        # Every node, inner or leaf, is stored as it is.
        self.data[path] = data
        if level >= self.max_level:
            return

        # One analysis step; presumably approximation plus horizontal,
        # vertical and diagonal detail coefficients.
        result_a, (result_h, result_v, result_d) = wavedec2(
            data, self.wavelet, 1, mode=self.mode
        )
        # assert for type checking
        assert not isinstance(result_a, tuple)
        self._recursive_dwt2d(result_a, level + 1, path + "a")
        self._recursive_dwt2d(result_h, level + 1, path + "h")
        self._recursive_dwt2d(result_v, level + 1, path + "v")
        self._recursive_dwt2d(result_d, level + 1, path + "d")