# Source code for jaxwt.utils

"""Various utility functions."""

# -*- coding: utf-8 -*-
from collections import namedtuple
from typing import List, Tuple, Union

import jax.numpy as jnp
import pywt

# Names exported by ``from jaxwt.utils import *``; only the coefficient-list
# flattening helper is part of the public surface.
__all__ = [
    "flatten_2d_coeff_lst",
]

# Minimal stand-in for a pywt wavelet: the field names mirror the
# decomposition/reconstruction filter attributes of ``pywt.Wavelet``.
Wavelet = namedtuple("Wavelet", "dec_lo dec_hi rec_lo rec_hi")


def flatten_2d_coeff_lst(
    coeff_list_2d: List[
        Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]
    ],
    flatten_arrays: bool = True,
) -> List[jnp.ndarray]:
    """Flatten a pywt-style coefficient list into a single list of arrays.

    Every entry of ``coeff_list_2d`` is either a single array or a group
    (tuple, or list) of arrays. Groups are unpacked so that each of their
    arrays becomes its own entry in the result; the original order is kept.

    Args:
        coeff_list_2d (list): A pywt-style coefficient list whose entries are
            arrays or tuples/lists of arrays.
        flatten_arrays (bool): If True, every array is flattened to 1d.
            Defaults to True.

    Returns:
        list: A single flat list holding all original arrays in order,
        each flattened to 1d when ``flatten_arrays`` is True.
    """
    flat_coeff_lst = []
    for coeff in coeff_list_2d:
        # Detail coefficients are grouped in a tuple in pywt's format; accept
        # lists too so groups rebuilt as lists (e.g. after serialization) are
        # unpacked instead of failing on ``list.flatten``.
        group = coeff if isinstance(coeff, (tuple, list)) else (coeff,)
        flat_coeff_lst.extend(c.flatten() if flatten_arrays else c for c in group)
    return flat_coeff_lst
def _as_wavelet(wavelet: Union[Wavelet, str]) -> pywt.Wavelet:
    """Coerce the argument into a pywt-wavelet-compatible object.

    Args:
        wavelet (Wavelet or str): Either an object that already behaves like a
            pywt wavelet, or the name of a wavelet known to pywt.

    Returns:
        pywt.Wavelet: ``pywt.Wavelet(wavelet)`` when a name string was given;
        otherwise the argument itself, returned unchanged.
    """
    return pywt.Wavelet(wavelet) if isinstance(wavelet, str) else wavelet