# Source code for jaxwt.conv_fwt_2d

"""Two dimensional convolution based fast wavelet transforms."""
#
# Created on Thu Jun 12 2020
# Copyright (c) 2020 Moritz Wolter
#
from typing import List, Optional, Tuple, Union

import jax
import jax.numpy as jnp
import pywt
from jax.config import config

from .conv_fwt import _get_filter_arrays
from .utils import Wavelet

# Switch JAX to 64-bit defaults. `jax_enable_x64` is a process-wide JAX flag,
# so importing this module affects every JAX computation in the process.
jax.config.update("jax_enable_x64", True)


def wavedec2(
    data: jnp.ndarray,
    wavelet: Wavelet,
    level: Optional[int] = None,
    mode: str = "reflect",
) -> List[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]]:
    """Compute the two dimensional wavelet analysis transform on the last two dimensions of the input data array.

    Args:
        data (jnp.array): Jax array containing the data to be transformed.
            Assumed shape: [batch size, height, width].
        wavelet (Wavelet): A named tuple containing the filters for the
            transformation.
        level (int): The max level to be used, if not set as many levels as
            possible will be used. Defaults to None.
        mode (str): The desired padding mode, forwarded to ``jnp.pad``.
            Choose e.g. reflect or symmetric; the pywt name "zero" is
            translated to numpy's "constant". Defaults to reflect.

    Returns:
        list: The wavelet coefficients in a nested list.
            The coefficients are in pywt order. That is:
            [cAn, (cHn, cVn, cDn), … (cH1, cV1, cD1)].
            A denotes approximation, H horizontal, V vertical
            and D diagonal coefficients. Floating point inputs produce
            coefficients of the same dtype.

    Examples:
        >>> import pywt
        >>> import jaxwt as jwt
        >>> import jax.numpy as jnp
        >>> data = jnp.arange(2 * 16 * 16, dtype=jnp.float64).reshape(2, 16, 16)
        >>> coeffs = jwt.wavedec2(data, pywt.Wavelet("haar"), level=2, mode="reflect")
    """
    # Add a singleton channel axis so the data is NCHW: [batch, 1, height, width].
    data = jnp.expand_dims(data, 1)

    # lax convolutions expect both operands to share a dtype, so build the
    # filters in the input's floating point dtype (float32 inputs would
    # otherwise meet float64 filters). Non-floating inputs keep the helper's
    # default filter dtype, as before.
    if jnp.issubdtype(data.dtype, jnp.floating):
        dec_lo, dec_hi, _, _ = _get_filter_arrays(
            wavelet, flip=True, dtype=data.dtype
        )
    else:
        dec_lo, dec_hi, _, _ = _get_filter_arrays(wavelet, flip=True)
    # Stacked ll, lh, hl, hh kernels as a [4, 1, L, L] (OIHW) tensor.
    dec_filt = construct_2d_filt(lo=dec_lo, hi=dec_hi)

    if mode == "zero":
        # translate pywt to numpy.
        mode = "constant"
    if level is None:
        level = pywt.dwtn_max_level(
            [data.shape[-1], data.shape[-2]], pywt.Wavelet("MyWavelet", wavelet)
        )

    result_lst: List[
        Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]
    ] = []
    res_ll = data
    for _ in range(level):
        res_ll = _fwt_pad2d(res_ll, len(wavelet), mode=mode)
        # A stride-2 'VALID' convolution filters and downsamples in one step.
        res = jax.lax.conv_general_dilated(
            lhs=res_ll,  # lhs = NCHW image tensor
            rhs=dec_filt,  # rhs = OIHW conv kernel tensor
            padding="VALID",
            window_strides=[2, 2],
            dimension_numbers=("NCHW", "OIHW", "NCHW"),
        )
        # One output channel per subband; split them apart again.
        res_ll, res_lh, res_hl, res_hh = jnp.split(res, 4, 1)
        result_lst.append((res_lh.squeeze(1), res_hl.squeeze(1), res_hh.squeeze(1)))
    result_lst.append(res_ll.squeeze(1))
    # Coarsest level first, matching the pywt coefficient order.
    result_lst.reverse()
    return result_lst
def waverec2(
    coeffs: List[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]],
    wavelet: Wavelet,
) -> jnp.ndarray:
    """Compute a two dimensional synthesis wavelet transform.

    Use it to reconstruct the original input image from the wavelet coefficients.

    Args:
        coeffs (list): The input coefficients, typically the output of wavedec2.
        wavelet (Wavelet): The named tuple containing the filters used to
            compute the analysis transform.

    Returns:
        jnp.array: Reconstruction of the original input data array of shape
            [batch, height, width]. An odd-sized input to wavedec2 may come
            back with one extra trailing row/column, because the finest level
            has no reference shape to crop against.

    Raises:
        ValueError: If `coeffs` is empty, does not start with the approximation
            array, or its per-level shapes do not fit together the way the
            output of `wavedec2` does.

    Example:
        >>> import pywt
        >>> import jaxwt as jwt
        >>> import jax.numpy as jnp
        >>> data = jnp.arange(2 * 16 * 16, dtype=jnp.float64).reshape(2, 16, 16)
        >>> transformed = jwt.wavedec2(data, pywt.Wavelet("haar"), level=2, mode="reflect")
        >>> jwt.waverec2(transformed, pywt.Wavelet("haar"))
    """
    if not coeffs or not isinstance(coeffs[0], jnp.ndarray):
        raise ValueError(
            "First element of coeffs must be the approximation coefficient tensor."
        )

    _, _, rec_lo, rec_hi = _get_filter_arrays(wavelet, flip=True, dtype=coeffs[0].dtype)
    filt_len = rec_lo.shape[-1]
    # construct_2d_filt returns [4, 1, L, L]; swapping the first two axes gives
    # an OIHW kernel mapping the four subband channels onto one output channel.
    rec_filt = jnp.transpose(construct_2d_filt(lo=rec_lo, hi=rec_hi), [1, 0, 2, 3])
    # Per-side padding added by _fwt_pad2d before every analysis step
    # (loop invariant, so computed once).
    base_pad = (2 * filt_len - 3) // 2

    res_ll = jnp.expand_dims(coeffs[0], 1)
    for c_pos, res_lh_hl_hh in enumerate(coeffs[1:]):
        # Stack ll, lh, hl and hh as channels: [batch, 4, height, width].
        res_ll = jnp.concatenate(
            [
                res_ll,
                jnp.expand_dims(res_lh_hl_hh[0], 1),
                jnp.expand_dims(res_lh_hl_hh[1], 1),
                jnp.expand_dims(res_lh_hl_hh[2], 1),
            ],
            1,
        )
        # Upsample and filter with a stride-2 transposed convolution.
        res_ll = jax.lax.conv_transpose(
            lhs=res_ll,
            rhs=rec_filt,
            padding="VALID",
            strides=[2, 2],
            dimension_numbers=("NCHW", "OIHW", "NCHW"),
        )

        # remove the padding
        padt = padb = padl = padr = base_pad
        if c_pos < len(coeffs) - 2:
            # The next level's detail coefficients fix the exact size this
            # approximation must have. Odd sizes were padded to even during
            # analysis, so one more trailing row/column may have to go.
            next_height = coeffs[c_pos + 2][0].shape[-2]
            next_width = coeffs[c_pos + 2][0].shape[-1]
            if res_ll.shape[-2] - (padt + padb) != next_height:
                padb += 1
            if res_ll.shape[-1] - (padl + padr) != next_width:
                padr += 1
            pred_height = res_ll.shape[-2] - (padt + padb)
            pred_width = res_ll.shape[-1] - (padl + padr)
            # Explicit raise instead of assert: asserts vanish under `python -O`.
            if (pred_height, pred_width) != (next_height, next_width):
                raise ValueError(
                    "padding error: reconstructed approximation of size "
                    f"{(pred_height, pred_width)} does not match the next level's "
                    f"coefficient size {(next_height, next_width)}. Make sure coeffs "
                    "was produced by wavedec2, otherwise please open an issue on github."
                )

        height, width = res_ll.shape[-2], res_ll.shape[-1]
        res_ll = res_ll[..., padt : height - padb, padl : width - padr]
    return res_ll.squeeze(1)
def construct_2d_filt(lo: jnp.ndarray, hi: jnp.ndarray) -> jnp.ndarray:
    """Build separable 2d wavelet kernels as outer products of 1d filters.

    Args:
        lo (jnp.array): 1d lowpass filter, e.g. of size [1, length] or [length];
            ``jnp.outer`` flattens its inputs, so both work.
        hi (jnp.array): 1d highpass filter with the same layout as ``lo``.

    Returns:
        jnp.array: The ll, lh, hl and hh kernels, in that order, stacked into
            an array of shape [4, 1, length, length].
    """
    # In jnp.outer(a, b) the first factor runs along the kernel's rows and the
    # second along its columns.
    factor_pairs = ((lo, lo), (hi, lo), (lo, hi), (hi, hi))
    kernels = [jnp.outer(rows, cols) for rows, cols in factor_pairs]
    # Insert a singleton second axis -> [4, 1, length, length].
    return jnp.stack(kernels, axis=0)[:, None, :, :]
def _fwt_pad2d(data: jnp.ndarray, filt_len: int, mode: str = "reflect") -> jnp.ndarray: padr = 0 padl = 0 padt = 0 padb = 0 if filt_len > 2: # we pad half of the total requried padding on each side. padr += (2 * filt_len - 3) // 2 padl += (2 * filt_len - 3) // 2 padt += (2 * filt_len - 3) // 2 padb += (2 * filt_len - 3) // 2 # pad to even singal length. if data.shape[-1] % 2 != 0: padr += 1 if data.shape[-2] % 2 != 0: padb += 1 data = jnp.pad(data, ((0, 0), (0, 0), (padt, padb), (padl, padr)), mode) return data