Source code for jaxwt.conv_fwt_2d

"""Two dimensional convolution based fast wavelet transforms."""

#
# Created on Thu Jun 12 2020
# Copyright (c) 2020 Moritz Wolter
#
from functools import partial
from typing import List, Optional, Tuple, Union

import jax
import jax.numpy as jnp
import pywt

from .conv_fwt import _get_filter_arrays
from .utils import (
    _adjust_padding_at_reconstruction,
    _as_wavelet,
    _check_axes_argument,
    _check_if_array,
    _fold_axes,
    _swap_axes,
    _undo_swap_axes,
    _unfold_axes,
)


def _preprocess_array_dec2d(
    data: jnp.ndarray,
) -> Tuple[jnp.ndarray, Optional[List[int]]]:
    """Bring the input into the 4d layout used by the 2d convolutions.

    A singleton channel axis is inserted at position one. Two-dimensional
    inputs additionally receive a leading singleton batch axis. Inputs with
    four or more dimensions are first passed through ``_fold_axes`` with
    ``keep_no=2``.

    Args:
        data (jnp.ndarray): The input array, at least two-dimensional.

    Returns:
        Tuple[jnp.ndarray, Optional[List[int]]]: The reshaped array and the
        second return value of ``_fold_axes`` if folding took place,
        otherwise None.

    Raises:
        ValueError: If the input has fewer than two dimensions.
    """
    ndim = data.ndim
    # Reject scalars as well as vectors up front; a 0-d array would
    # otherwise pass through unchanged and fail later in the convolution.
    if ndim < 2:
        raise ValueError("More than one input dimension required.")

    ds = None
    if ndim == 2:
        # [height, width] -> [1, 1, height, width]
        return jnp.expand_dims(data, (0, 1)), ds
    if ndim >= 4:
        # presumably merges the leading axes into one batch axis and
        # returns what is needed to restore them later -- see _fold_axes.
        data, ds = _fold_axes(data, 2)
    # [batch, height, width] -> [batch, 1, height, width]
    return jnp.expand_dims(data, 1), ds


def wavedec2(
    data: jnp.ndarray,
    wavelet: Union[pywt.Wavelet, str],
    mode: str = "symmetric",
    level: Optional[int] = None,
    axes: Tuple[int, int] = (-2, -1),
    precision: str = "highest",
) -> List[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]]:
    """Compute the two-dimensional wavelet analysis transform.

    By default the transform runs over the last two dimensions of the
    input data array.

    Args:
        data (jnp.ndarray): Jax array containing the data to be transformed.
            A possible input shape would be [batch size, height, width].
        wavelet (Union[pywt.Wavelet, str]): A wavelet object or wavelet
            string for the transformation.
            Check pywt.wavelist() for a list of options.
        mode (str): The desired padding mode. Choose "reflect",
            "symmetric" or "zero". Defaults to symmetric.
        level (int): The max level to be used, if not set as many levels
            as possible will be used. Defaults to None.
        axes (Tuple[int, int]): Compute the transform over these axes
            instead of the last two. Defaults to (-2, -1).
        precision (str): For the desired precision, choose "fastest",
            "high" or "highest". Defaults to "highest".

    Returns:
        list: The wavelet coefficients of shape [batch, height, width]
        in a nested list. The coefficients are in pywt order. That is:
        [cAn, (cHn, cVn, cDn), … (cH1, cV1, cD1)].
        A denotes approximation, H horizontal, V vertical
        and D diagonal coefficients.

    Raises:
        ValueError: If the axes tuple does not have two elements
            or contains a repetition.

    Examples:
        >>> import pywt, scipy.datasets
        >>> import jaxwt as jwt
        >>> import jax.numpy as jnp
        >>> face = jnp.transpose(scipy.datasets.face(), [2, 0, 1])
        >>> face = face.astype(jnp.float32)
        >>> jwt.wavedec2(face, "haar", level=2)
    """
    wavelet = _as_wavelet(wavelet)

    # Decide once whether the axes must be swapped, so the swap on entry
    # and the undo on exit always agree (also when axes is given as a list).
    custom_axes = tuple(axes) != (-2, -1)
    if custom_axes:
        if len(axes) != 2:
            raise ValueError("2d transforms work with two axes.")
        # Same validation waverec2 performs before swapping.
        _check_axes_argument(list(axes))
        data = _swap_axes(data, list(axes))

    data, ds = _preprocess_array_dec2d(data)
    dec_lo, dec_hi, _, _ = _get_filter_arrays(wavelet, flip=True, dtype=data.dtype)
    dec_filt = _construct_2d_filt(lo=dec_lo, hi=dec_hi)

    if level is None:
        level = pywt.dwtn_max_level([data.shape[-1], data.shape[-2]], wavelet)

    result_list: List[
        Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]
    ] = []
    res_ll = data
    for _ in range(level):
        res_ll = _fwt_pad2d(res_ll, len(wavelet), mode=mode)
        res = jax.lax.conv_general_dilated(
            lhs=res_ll,  # lhs = NCHW image tensor
            rhs=dec_filt,  # rhs = OIHW conv kernel tensor
            padding="VALID",
            window_strides=[2, 2],
            dimension_numbers=("NCHW", "OIHW", "NCHW"),
            precision=jax.lax.Precision(precision),
        )
        # The four output channels follow the filter stacking order
        # of _construct_2d_filt: ll, lh, hl, hh.
        res_ll, res_lh, res_hl, res_hh = jnp.split(res, 4, 1)
        result_list.append((res_lh.squeeze(1), res_hl.squeeze(1), res_hh.squeeze(1)))
    result_list.append(res_ll.squeeze(1))
    # Coarsest level first, as in pywt.
    result_list.reverse()

    if ds:
        _unfold_axes2 = partial(_unfold_axes, ds=ds, keep_no=2)
        result_list = jax.tree_util.tree_map(_unfold_axes2, result_list)

    if custom_axes:
        to_tree = partial(_undo_swap_axes, axes=list(axes))
        result_list = jax.tree_util.tree_map(to_tree, result_list)

    return result_list
def waverec2(
    coeffs: List[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]],
    wavelet: Union[pywt.Wavelet, str],
    axes: Tuple[int, int] = (-2, -1),
    precision: str = "highest",
) -> jnp.ndarray:
    """Compute a two-dimensional synthesis wavelet transform.

    Use it to reconstruct the original input image from the wavelet
    coefficients.

    Args:
        coeffs (list): The input coefficients, typically the output of
            wavedec2.
        wavelet (Union[pywt.Wavelet, str]): The wavelet to use for the
            synthesis transform.
        axes (Tuple[int, int]): Compute the transform over these axes
            instead of the last two. Defaults to (-2, -1).
        precision (str): For desired precision, choose "fastest", "high"
            or "highest". Defaults to "highest".

    Raises:
        ValueError: If the axes tuple does not have two elements
            or contains a repetition.

    Returns:
        jnp.ndarray: Reconstruction of the original input data array
        of shape [batch, height, width].

    Example:
        >>> import pywt, scipy.datasets
        >>> import jaxwt as jwt
        >>> import jax.numpy as jnp
        >>> face = jnp.transpose(scipy.datasets.face(), [2, 0, 1])
        >>> face = face.astype(jnp.float32)
        >>> transformed = jwt.wavedec2(face, "haar")
        >>> jwt.waverec2(transformed, "haar")
    """
    wavelet = _as_wavelet(wavelet)

    # Decide once whether the axes must be swapped, so the swap on entry
    # and the undo on exit always agree (also when axes is given as a list).
    custom_axes = tuple(axes) != (-2, -1)
    if custom_axes:
        if len(axes) != 2:
            raise ValueError("2d transforms work with two axes.")
        _check_axes_argument(list(axes))
        to_tree = partial(_swap_axes, axes=list(axes))
        coeffs = jax.tree_util.tree_map(to_tree, coeffs)

    ds = None
    if _check_if_array(coeffs[0]).ndim > 3:
        # TODO: check this!
        ds = list(_check_if_array(coeffs[0]).shape)

        def _fold_keep2(array: jnp.ndarray) -> jnp.ndarray:
            # Only the folded array is needed, not the shape information.
            return _fold_axes(array, keep_no=2)[0]

        fcoeffs = jax.tree_util.tree_map(_fold_keep2, coeffs)
    else:
        fcoeffs = coeffs

    _, _, rec_lo, rec_hi = _get_filter_arrays(
        wavelet, flip=True, dtype=_check_if_array(fcoeffs[0]).dtype
    )
    filt_len = rec_lo.shape[-1]
    rec_filt = _construct_2d_filt(lo=rec_lo, hi=rec_hi)
    # Swap the first two filter axes: [4, 1, L, L] -> [1, 4, L, L].
    rec_filt = jnp.transpose(rec_filt, [1, 0, 2, 3])

    # The same amount is trimmed from every side unless adjusted below.
    base_pad = (2 * filt_len - 3) // 2

    res_ll = jnp.expand_dims(_check_if_array(fcoeffs[0]), 1)
    for c_pos, res_lh_hl_hh in enumerate(fcoeffs[1:]):
        # Stack approximation and detail coefficients along the channel axis.
        res_ll = jnp.concatenate(
            [
                res_ll,
                jnp.expand_dims(res_lh_hl_hh[0], 1),
                jnp.expand_dims(res_lh_hl_hh[1], 1),
                jnp.expand_dims(res_lh_hl_hh[2], 1),
            ],
            1,
        )
        res_ll = jax.lax.conv_transpose(
            lhs=res_ll,
            rhs=rec_filt,
            padding="VALID",
            strides=[2, 2],
            dimension_numbers=("NCHW", "OIHW", "NCHW"),
            precision=jax.lax.Precision(precision),
        )

        # remove the padding
        padl = padr = padt = padb = base_pad
        if c_pos < len(fcoeffs) - 2:
            # Not the last level yet: adjust the trimming using the spatial
            # size of the next level's detail coefficients.
            padr, padl = _adjust_padding_at_reconstruction(
                res_ll.shape[-1], fcoeffs[c_pos + 2][0].shape[-1], padr, padl
            )
            padb, padt = _adjust_padding_at_reconstruction(
                res_ll.shape[-2], fcoeffs[c_pos + 2][0].shape[-2], padb, padt
            )
        # The > 0 guards matter: a slice like [:-0] would be empty.
        if padt > 0:
            res_ll = res_ll[..., padt:, :]
        if padb > 0:
            res_ll = res_ll[..., :-padb, :]
        if padl > 0:
            res_ll = res_ll[..., padl:]
        if padr > 0:
            res_ll = res_ll[..., :-padr]

    res_ll = res_ll.squeeze(1)

    if ds:
        res_ll = _unfold_axes(res_ll, ds, 2)

    if custom_axes:
        res_ll = _undo_swap_axes(res_ll, list(axes))

    return res_ll
def _construct_2d_filt(lo: jnp.ndarray, hi: jnp.ndarray) -> jnp.ndarray:
    """Build the four separable 2d filters from a 1d filter pair.

    Args:
        lo (jnp.ndarray): 1d lowpass input filter of size [1, length].
        hi (jnp.ndarray): 1d highpass input filter of size [1, length].

    Returns:
        jnp.ndarray: 2d filter arrays of shape [4, 1, length, length],
        stacked in the order ll, lh, hl, hh.
    """
    # (row filter, column filter) pairs in output order: ll, lh, hl, hh.
    filter_pairs = ((lo, lo), (hi, lo), (lo, hi), (hi, hi))
    stacked = jnp.stack([jnp.outer(first, second) for first, second in filter_pairs], 0)
    # Insert the singleton axis at position one.
    return jnp.expand_dims(stacked, 1)


def _fwt_pad2d(
    data: jnp.ndarray, filt_len: int, mode: str = "symmetric"
) -> jnp.ndarray:
    """Pad the last two axes of a 4d array ahead of a transform step.

    Args:
        data (jnp.ndarray): Four-dimensional input; only the last two
            axes are padded.
        filt_len (int): Length of the wavelet filter.
        mode (str): Padding mode handed to ``jnp.pad``. The pywt name
            "zero" is mapped to "constant". Defaults to "symmetric".

    Returns:
        jnp.ndarray: The padded array.
    """
    # Half of the total required padding goes on each side.
    half = (2 * filt_len - 3) // 2 if filt_len > 2 else 0

    # One extra element at the bottom/right makes odd lengths even.
    width_extra = 1 if data.shape[-1] % 2 != 0 else 0
    height_extra = 1 if data.shape[-2] % 2 != 0 else 0

    # Translate the pywt mode name to its numpy counterpart.
    pad_mode = "constant" if mode == "zero" else mode

    pad_width = (
        (0, 0),
        (0, 0),
        (half, half + height_extra),
        (half, half + width_extra),
    )
    return jnp.pad(data, pad_width, pad_mode)