Source code for jaxwt.conv_fwt

"""Convolution based fast wavelet transforms."""

# -*- coding: utf-8 -*-

# Created on Thu Jun 11 2020
# Copyright (c) 2020 Moritz Wolter
#
from typing import Any, List, Optional, Tuple, Union

import jax
import jax.lax
import jax.numpy as jnp
import pywt

from .utils import _as_wavelet, _check_if_array, _fold_axes, _unfold_axes


def _preprocess_array_dec1d(
    data: jnp.ndarray,
) -> Tuple[jnp.ndarray, Union[List[int], None]]:
    """Bring the input into a three-axis layout with a singleton axis at position 1.

    Args:
        data (jnp.ndarray): The array to transform along its last axis.

    Returns:
        tuple: The reshaped array and ``ds``. ``ds`` is ``None`` for inputs
            with one or two axes; otherwise it is the second value returned
            by ``_fold_axes(data, 1)`` (presumably the original shape, which
            is needed to undo the folding -- confirm in ``utils``).
    """
    n_axes = len(data.shape)
    if n_axes == 1:
        # A bare signal: prepend both a batch and a channel axis.
        return jnp.expand_dims(data, (0, 1)), None
    if n_axes == 2:
        # Batched signals: only the channel axis is missing.
        return jnp.expand_dims(data, 1), None
    # Any other rank: fold first, then insert the channel axis.
    folded, ds = _fold_axes(data, 1)
    return jnp.expand_dims(folded, 1), ds


def _postprocess_result_list_dec1d(
    result_lst: List[jnp.ndarray], ds: List[int]
) -> List[jnp.ndarray]:
    """Apply ``_unfold_axes(coeff, ds, 1)`` to every coefficient array.

    Args:
        result_lst (List[jnp.ndarray]): Coefficient arrays to unfold.
        ds (List[int]): The shape information handed to ``_unfold_axes``.

    Returns:
        List[jnp.ndarray]: A new list with the unfolded arrays in the
            same order as the input.
    """
    return [_unfold_axes(coeff, ds, 1) for coeff in result_lst]


def _preprocess_result_list_rec1d(
    result_lst: List[jnp.ndarray],
) -> Tuple[List[jnp.ndarray], List[int]]:
    """Fold every coefficient array and remember the shape of the first one.

    Args:
        result_lst (List[jnp.ndarray]): Coefficient arrays to fold.

    Returns:
        tuple: The list of folded arrays (first return value of
            ``_fold_axes(coeff, 1)`` for each entry) and the shape of the
            first, still unfolded, entry as a list of ints.
    """
    # The shape is read before folding; the first entry is validated by
    # ``_check_if_array`` on the way.
    ds = list(_check_if_array(result_lst[0]).shape)
    fold_coeffs = [_fold_axes(coeff, 1)[0] for coeff in result_lst]
    return fold_coeffs, ds


def wavedec(
    data: jnp.ndarray,
    wavelet: Union[pywt.Wavelet, str],
    mode: str = "symmetric",
    level: Optional[int] = None,
    axis: int = -1,
    precision: str = "highest",
) -> List[jnp.ndarray]:
    """Compute the analysis wavelet transform of the last dimension.

    Args:
        data (jnp.ndarray): Input data array. I.e. of shape [batch, time].
        wavelet (pywt.Wavelet): A wavelet name-string or a wavelet object
            containing the wavelet filter arrays.
            Check pywt.wavelist() for a list of options.
        mode (str): The padding used to extend the input signal.
            Choose reflect, symmetric or zero. Defaults to symmetric.
        level (int): Max scale level to be used. If None, as many levels
            as possible are used. Defaults to None.
        axis (int): Compute the transform over this axis instead of the
            last one. Defaults to -1.
        precision (str): For desired precision, choose "fastest", "high"
            or "highest". Defaults to "highest".

    Returns:
        list: List containing the wavelet coefficients.
            The coefficients are in ``pywt`` order:
            [cA_n, cD_n, cD_n-1, …, cD2, cD1].
            A denotes approximation and D detail coefficients.

    Raises:
        ValueError: If the axis argument is not an integer.

    Examples:
        >>> import pywt
        >>> import jaxwt as jwt
        >>> import jax.numpy as jnp
        >>> # generate an input of even length.
        >>> data = jnp.array([0., 1., 2., 3, 4, 5, 5, 4, 3, 2, 1, 0])
        >>> jwt.wavedec(data, wavelet=pywt.Wavelet('haar'), level=2)
    """
    if axis != -1:
        if isinstance(axis, int):
            # Move the requested axis to the end; it is moved back below.
            data = data.swapaxes(axis, -1)
        else:
            raise ValueError("wavedec transforms a single axis only.")

    wavelet = _as_wavelet(wavelet)
    data, ds = _preprocess_array_dec1d(data)
    dec_lo, dec_hi, _, _ = _get_filter_arrays(wavelet, flip=True, dtype=data.dtype)
    filt_len = dec_lo.shape[-1]
    # Low- and high-pass filters become the two output channels of one kernel.
    filt = jnp.stack([dec_lo, dec_hi], 0)

    if level is None:
        level = pywt.dwt_max_level(data.shape[-1], filt_len)

    # The precision setting does not change between levels.
    conv_precision = jax.lax.Precision(precision)

    result_list = []
    res_lo = data
    for _ in range(level):
        res_lo = _fwt_pad(res_lo, filt_len, mode=mode)
        res = jax.lax.conv_general_dilated(
            lhs=res_lo,  # lhs = NCH image tensor
            rhs=filt,  # rhs = OIH conv kernel tensor
            padding="VALID",
            window_strides=[
                2,
            ],
            dimension_numbers=("NCH", "OIH", "NCH"),
            precision=conv_precision,
        )
        # Channel 0 feeds the next level, channel 1 is this level's detail.
        res_lo, res_hi = jnp.split(res, 2, 1)
        result_list.append(res_hi.squeeze(1))
    result_list.append(res_lo.squeeze(1))
    # Coarsest coefficients first, as in pywt.
    result_list.reverse()

    if ds:
        result_list = _postprocess_result_list_dec1d(result_list, ds)

    if axis != -1:
        result_list = [coeff.swapaxes(axis, -1) for coeff in result_list]

    return result_list
def waverec(
    coeffs: List[jnp.ndarray],
    wavelet: Union[pywt.Wavelet, str],
    axis: int = -1,
    precision: str = "highest",
) -> jnp.ndarray:
    """Reconstruct the original signal in one dimension.

    Args:
        coeffs (List[jnp.ndarray]): Wavelet coefficients, typically produced
            by the ``wavedec`` function. List entries of shape
            [batch_size, coefficients] work.
        wavelet (Union[pywt.Wavelet, str]): A string with a wavelet name or
            a wavelet object containing the wavelet filters used to
            evaluate the decomposition.
        axis (int): Transform this axis instead of the last one.
            Defaults to -1.
        precision (str): The desired precision, choose "fastest", "high"
            or "highest". Defaults to "highest".

    Returns:
        jnp.ndarray: Reconstruction of the original data.

    Raises:
        ValueError: If the axis argument is not an integer.

    Examples:
        >>> import pywt
        >>> import jaxwt as jwt
        >>> import jax.numpy as jnp
        >>> # generate an input of even length.
        >>> data = jnp.array([0., 1., 2., 3, 4, 5, 5, 4, 3, 2, 1, 0])
        >>> transformed = jwt.wavedec(data, pywt.Wavelet('haar'))
        >>> jwt.waverec(transformed, pywt.Wavelet('haar'))
    """
    if axis != -1:
        if isinstance(axis, int):
            # Move the requested axis to the end; it is moved back below.
            coeffs = [coeff.swapaxes(axis, -1) for coeff in coeffs]
        else:
            raise ValueError("waverec transforms a single axis only.")

    ds = None
    if coeffs[0].ndim > 2:
        coeffs, ds = _preprocess_result_list_rec1d(coeffs)

    wavelet = _as_wavelet(wavelet)
    # unlike pytorch lax's transpose conv requires filter flips.
    _, _, rec_lo, rec_hi = _get_filter_arrays(wavelet, flip=True, dtype=coeffs[0].dtype)
    filt_len = rec_lo.shape[-1]
    # Low- and high-pass filters become the two input channels of one kernel.
    filt = jnp.stack([rec_lo, rec_hi], 1)

    # The precision setting does not change between levels.
    conv_precision = jax.lax.Precision(precision)

    res_lo = coeffs[0]
    for c_pos, res_hi in enumerate(coeffs[1:]):
        # Approximation and detail coefficients form the two input channels.
        res_lo = jnp.stack([res_lo, res_hi], 1)
        res_lo = jax.lax.conv_transpose(
            lhs=res_lo,
            rhs=filt,
            padding="VALID",
            strides=[
                2,
            ],
            dimension_numbers=("NCH", "OIH", "NCH"),
            precision=conv_precision,
        )
        res_lo = _fwt_unpad(res_lo, filt_len, c_pos, coeffs)
        res_lo = res_lo.squeeze(1)

    if ds:
        res_lo = _unfold_axes(res_lo, ds, 1)

    if axis != -1:
        res_lo = res_lo.swapaxes(axis, -1)

    return res_lo
def _fwt_unpad(
    res_lo: jnp.ndarray, filt_len: int, c_pos: int, coeffs: List[jnp.ndarray]
) -> jnp.ndarray:
    """Trim the surplus samples a transposed convolution adds to the last axis.

    Args:
        res_lo (jnp.ndarray): Output of the transposed convolution.
        filt_len (int): The length of the wavelet filters.
        c_pos (int): Position of the current reconstruction step; the
            coefficient array of the following step is ``coeffs[c_pos + 2]``.
        coeffs (List[jnp.ndarray]): The full coefficient list.

    Returns:
        jnp.ndarray: ``res_lo`` with samples removed from both ends of its
            last axis.
    """
    trim_front = 0
    trim_back = 0
    if filt_len > 2:
        trim_front += (2 * filt_len - 3) // 2
        trim_back += (2 * filt_len - 3) // 2
    if c_pos < len(coeffs) - 2:
        pred_len = res_lo.shape[-1] - (trim_front + trim_back)
        nex_len = coeffs[c_pos + 2].shape[-1]
        if nex_len != pred_len:
            # The lengths disagree, so drop one more sample at the end
            # (``_fwt_pad`` adds one sample there for odd-length inputs).
            trim_back += 1
    if trim_back == 0:
        # A stop index of ``-0`` would produce an empty slice.
        return res_lo[..., trim_front:]
    return res_lo[..., trim_front:-trim_back]


def _fwt_pad(data: jnp.ndarray, filt_len: int, mode: str = "reflect") -> jnp.ndarray:
    """Pad an input to ensure our fwts are invertible.

    Args:
        data (jnp.ndarray): The input array.
        filt_len (int): The length of the wavelet filters.
        mode (str): How to pad. The pywt name "zero" is translated to
            "constant"; every other value is passed to ``jnp.pad``
            unchanged. Defaults to "reflect".

    Returns:
        jnp.array: A padded version of the input data array.
    """
    # Pad so that we see all filter positions and for pywt compatibility.
    # convolution output length:
    # see https://arxiv.org/pdf/1603.07285.pdf section 2.3:
    # floor([data_len - filt_len]/2) + 1
    # should equal pywt output length
    # floor((data_len + filt_len - 1)/2)
    # => floor([data_len + total_pad - filt_len]/2) + 1
    #    = floor((data_len + filt_len - 1)/2)
    # (data_len + total_pad - filt_len) + 2 = data_len + filt_len - 1
    # total_pad = 2*filt_len - 3
    if mode == "zero":
        # translate pywt to numpy.
        mode = "constant"

    padr = (2 * filt_len - 3) // 2
    padl = (2 * filt_len - 3) // 2

    # pad to even signal length.
    if data.shape[-1] % 2 != 0:
        padr += 1

    # Only the last axis is padded.
    data = jnp.pad(data, [(0, 0)] * (data.ndim - 1) + [(padl, padr)], mode)
    return data


def _get_filter_arrays(
    wavelet: "pywt.Wavelet", flip: bool, dtype: jnp.dtype[Any] = jnp.float64
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Extract the filter coefficients from an input wavelet object.

    Args:
        wavelet (pywt.Wavelet): A pywt-style input wavelet. A wavelet
            name-string or an iterable holding the four filters
            (dec_lo, dec_hi, rec_lo, rec_hi) is accepted as well.
        flip (bool): If true flip the input coefficients.
        dtype: The desired precision. Defaults to jnp.float64 .

    Returns:
        tuple: The dec_lo, dec_hi, rec_lo and rec_hi filter coefficients
            as jax arrays, each with a leading axis of size one.
    """

    def create_array(coeffs: Union[List[float], jnp.ndarray]) -> jnp.ndarray:
        # jnp.asarray accepts Python sequences as well as arrays, so no
        # dispatch on the input type is needed.
        arr = jnp.asarray(coeffs)
        if flip:
            arr = arr[::-1]
        return jnp.expand_dims(arr, 0)

    if isinstance(wavelet, str):
        wavelet = pywt.Wavelet(wavelet)

    if isinstance(wavelet, pywt.Wavelet):
        filter_bank = wavelet.filter_bank
    else:
        # Anything else is expected to unpack into the four filters.
        filter_bank = wavelet
    dec_lo, dec_hi, rec_lo, rec_hi = filter_bank

    dec_lo = create_array(dec_lo).astype(dtype)
    dec_hi = create_array(dec_hi).astype(dtype)
    rec_lo = create_array(rec_lo).astype(dtype)
    rec_hi = create_array(rec_hi).astype(dtype)
    return dec_lo, dec_hi, rec_lo, rec_hi