Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
2365bad
split out WaveletBasis and EZ transforms into separate submodules
paulthebaker Feb 1, 2018
cba7cf2
import new submodules
paulthebaker Feb 1, 2018
d8f6869
use mathjax in docs
paulthebaker Feb 1, 2018
d328b7b
rewrite Morlet docstrings
paulthebaker Feb 1, 2018
cff9815
skeletal Wavelet base class
paulthebaker Feb 1, 2018
b83fb83
Merge branch 'master' into ptb-restruct
paulthebaker Feb 1, 2018
4c8e8cc
new docs sections
paulthebaker Feb 1, 2018
225bdaa
flake8 fixes
paulthebaker Feb 1, 2018
9478d7c
rename
paulthebaker Feb 1, 2018
7828b9c
use new name (basis)
paulthebaker Feb 1, 2018
c84c504
use correct name
paulthebaker Feb 1, 2018
9fdeee8
rewrite easy.py docstrings
paulthebaker Feb 1, 2018
16b9969
"sample cadence" not "sampling time"
paulthebaker Feb 1, 2018
dc2df67
use correct names
paulthebaker Feb 1, 2018
5bc7b89
docstrings
paulthebaker Feb 1, 2018
8f144e0
rewrite docstrings
paulthebaker Feb 2, 2018
c34c1dc
docstrings for PaulWave
paulthebaker Feb 2, 2018
d9bc3d2
basic usage docs
paulthebaker Feb 2, 2018
1a25f17
one python badge
paulthebaker Feb 2, 2018
041fc8e
update docs
paulthebaker Feb 2, 2018
f076902
convert ipynb to rst
paulthebaker Feb 2, 2018
448881e
nb checkpoints
paulthebaker Feb 2, 2018
54d69d6
remove trailing cells
paulthebaker Feb 2, 2018
dd1d2de
remove trailing cells
paulthebaker Feb 2, 2018
181f9e6
docs
paulthebaker Apr 24, 2018
8e3276a
move __call__ to base class
paulthebaker Aug 20, 2018
6a4a2f5
Merge branch 'master' into ptb-restruct
paulthebaker Aug 20, 2018
b583676
whitespace
paulthebaker Aug 20, 2018
a3741eb
do cwt in fourier domain
paulthebaker Aug 20, 2018
2776987
Merge branch 'ptb-restruct' of https://github.com/paulthebaker/ceedub…
paulthebaker Aug 20, 2018
2b00b19
fix normalization
paulthebaker Aug 20, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ __pycache__/
*.py[cod]
*$py.class

# IPython notebooks
.ipynb_checkpoints/

# C extensions
*.so

Expand Down
6 changes: 1 addition & 5 deletions README.rst
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
======
ceedub
======
|py27| |py35| |py36|

.. |py27| image:: https://img.shields.io/badge/python-2.7-blue.svg
.. |py35| image:: https://img.shields.io/badge/python-3.5-blue.svg
.. |py36| image:: https://img.shields.io/badge/python-3.6-blue.svg
.. image:: https://img.shields.io/badge/python-2.7%2C%203.5%2C%203.6-blue.svg

.. image:: https://img.shields.io/pypi/v/ceedub.svg
:target: https://pypi.python.org/pypi/ceedub
Expand Down
8 changes: 7 additions & 1 deletion ceedub/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
# -*- coding: utf-8 -*-
# flake8: noqa ignore=F401
"""Continuous wavelet transform and support functions
based on Torrence and Compo 1998 (T&C)

from .wavelet import WaveletBasis, cwt, icwt, cwtfreq
(http://paos.colorado.edu/research/wavelets/bams_79_01_0061.pdf)
"""

from .easy import cwt, icwt, cwtfreq
from .wavelet import MorletWave, PaulWave
from .basis import WaveletBasis

__author__ = """Paul T. Baker"""
__email__ = 'paultbaker@gmail.com'
Expand Down
230 changes: 230 additions & 0 deletions ceedub/basis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
# -*- coding: utf-8 -*-
"""WaveletBasis object containing transform methods
"""

from __future__ import (absolute_import, division,
print_function, unicode_literals)

import numpy as np

from .wavelet import MorletWave


class WaveletBasis(object):
"""An object containing a basis for wavelet transforms

The basis is used for forward and inverse transforms of data using
the same sample rate and frequency scales. At initialization given
``N``, ``dt``, and ``dj``, the scales will be computed from the
``_get_scales()`` method based on the Nyquist period of the wavelet
and the length of the data.

See T&C section 3.f for more information about how scales are choosen.
"""
def __init__(self, wavelet=None, N=None, dt=1, dj=1/16):
"""initialize ``WaveletBasis`` object

:param wavelet:
Wavelet basis function which takes two arguments: ``t`` and
``s``. ``t`` is the time to evaluate the wavelet function.
``s`` is the scale or width parameter. The wavelet function
should be normalized to unit weight at scale=1, and have
zero mean.
:param N:
length of time domain data that will be transformed
:param dt:
sample cadence of data, needed for normalization of transforms
:param dj:
scale step size, used to determine the frequency scales to use
for the transform

Note: the wavelet function used to generate a basis has different
requirements than ``scipy.signal.cwt``. The Ricker and Morlet
wavelet functions provided in ``scipy.signal`` are incompatible
with this function. Instances of the ``MorletWave`` and ``PaulWave``
callable objects provided in ``ceedub.wavelet`` can be used.
"""
if wavelet is None:
wavelet = MorletWave() # default to Morlet, w0=6
if not isinstance(N, int):
raise TypeError("N must be an integer")

self._wavelet = wavelet
self._dt = dt
self._dj = dj
self._N = N

self._inv_root_scales = 1./np.sqrt(self.scales)

# don't provide setters for properties!
# all are determined at creation and frozen!
@property
def wavelet(self):
"""wavelet basis function

:param t:
time or array of times to calculate wavelet
:param s:
scale parameter for wavelet
"""
return self._wavelet

@property
def dt(self):
return self._dt

@property
def dj(self):
return self._dj

@property
def N(self):
return self._N

@property
def s0(self):
if not hasattr(self, '_s0'):
try:
self._s0 = self.wavelet.nyquist_scale(self.dt)
except AttributeError:
self._s0 = 2*self.dt
return self._s0

@property
def scales(self):
if not hasattr(self, '_scales'):
self._scales = self._get_scales()
return self._scales

@property
def M(self):
if not hasattr(self, '_M'):
self._M = len(self.scales)
return self._M

@property
def times(self):
"""sample times of data"""
if not hasattr(self, '_times'):
self._times = np.arange(self.N) * self.dt
return self._times

@property
def freqs(self):
if not hasattr(self, '_freqs'):
try:
self._freqs = 1./self.wavelet.fourier_period(self.scales)
except AttributeError:
self._freqs = 1./self.scales
return self._freqs

def cwt(self, tdat):
"""compute continuous wavelet transfomrm

:param tdat:
shape ``(N,)`` array of real, time domain data

:return wdat:
shape ``(M,N)`` array of complex, wavelet domain data.
``M`` is the number of scales used in the transform, and ``N``
is the length of the input time domain data.

Uses the wavelet function and scales of the WaveletBasis. The
transform is calculated via FFT convolution as described in T&C.
The FFT convolution is computed once at each wavelet scale,
determining the frequecny resolution of the output.
"""
if len(tdat) != self.N:
raise ValueError("tdat is not length N={:d}".format(self.N))

fdat = np.fft.fft(tdat)
ws = 2*np.pi * np.fft.fftfreq(len(tdat), d=self.dt)
dW = ws[1] - ws[0]
waves = [self.wavelet.freq(ws, s) for s in self.scales]

wdat = np.fft.ifft(waves * fdat) # convolution theorem
norms = np.sqrt(self.scales * dW * self.N)

return np.einsum('i,ij->ij', norms, wdat)

def icwt(self, wdat):
"""compute the inverse continuous wavelet transform

:param wdat:
shape ``(M,N)`` array of complex, wavelet domain data.
``M`` is the number of frequency scales, and ``N`` is
the number of time samples.

:returns tdat:
shape ``(N,)`` array of real, time domain data

The inverse continuous wavelet transform is computed following
T&C section 3.i. Uses the wavelet function and scales of the
parent WaveletBasis.
"""
if not hasattr(self, '_recon_norm'):
self._recon_norm = self._get_recon_norm()
M = self.M
N = self.N
if wdat.shape != (M, N):
raise ValueError("wdat is not shape ({0:d},{1:d})".format(M, N))
irs = self._inv_root_scales
tdat = np.einsum('ij,i->j', np.real(wdat), irs)
tdat *= self._recon_norm
return tdat

def _get_scales(self):
r"""determine the frequency scales used in the transform

:return scales:
array of scale parameters, s, for use in ``cwt``

The frequency scales have log2 frequency spacing. They are
chosen such that :math:`s_0` is the 'smallest' scale, :math:`dj`
is the scale step size (in log-space), and :math:`\log_2(N)` is
the number of octaves. Traditionally,

.. math:
s_j &= s_0 \cdot 2^{j\cdot dj}, \text{ for } j \in [0,J] \\
J &= \log_2(N) / dj

If the wavelet used contains a ``nyquist_scale()`` method, :math:`s_0`
will correspond to the Nyquist frequency. The largest scale is
has a frequency given by the observation time: :math:`1/(2 T_{obs}).

In practice the inverse transforms are more accurate when scales
outside of the usual Fourier frequencies are used. We use an extra
eight octaves, the scales for :math:`j\in [-4, J+4]`.
"""
N = self.N
dj = self.dj
s0 = self.s0
Noct = np.log2(N)+1 # number of scale octaves
J = int(Noct / dj) # total number of scales
s = [s0 * 2**(j * dj) for j in range(-4, J)]
return np.array(s)

def _get_recon_norm(self):
r"""compute the normalization factor for ICWT

This is not :math:`C_\delta` from T&C, this is a normalization
constant :math:`A`, such that in the ICWT calculation

.. math: A\, \sum_k \psi(w_k) = tdat.

This constant eliminates some factors which explicitly cancel in
later calculations, for example
:math:`\frac{dj\cdot dt^{1/2}}{\psi_0}`.
"""
N = self.N
dt = self.dt
scales = self.scales
Psi_f = self.wavelet.freq # f-domain wavelet as f(w_k, s)
w_k = 2*np.pi * np.fft.rfftfreq(N, dt) # Fourier freqs

W_d = np.zeros_like(scales)
for ii, sc in enumerate(scales):
norm = np.sqrt(2*np.pi / dt)
W_d[ii] = np.sum(Psi_f(w_k, s=sc).conj()) * norm
W_d /= N
return 1/np.sum(np.real(W_d))
68 changes: 68 additions & 0 deletions ceedub/easy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-
"""easy to use methods for default CWT and ICWT
"""

from __future__ import (absolute_import, division,
print_function, unicode_literals)

from .basis import WaveletBasis


def cwt(tdat, dt=1):
"""Compute continuous wavelet transform, using default ``WaveletBasis``.

:param tdat:
shape ``(N,)`` array of real, time domain data
:param dt:
sample cadence of data, needed for normalization of transforms

:return wdat:
shape ``(M,N)`` array of complex, wavelet domain data. ``M`` is
the number of scales used in the transform, and ``N`` is the
length of the input time domain data.

If you plan on doing several CWTs in the same basis you should
consider initializing a ``WaveletBasis`` object and using:
``WaveletBasis.cwt()``.
"""
WB = WaveletBasis(N=len(tdat), dt=dt)
return WB.cwt(tdat)


def icwt(wdat, dt=1):
"""Compute inverse continuous wavelet transform, using default
``WaveletBasis``.

:param wdat:
shape ``(M,N)`` array of complex, wavelet domain data. ``M`` is
the number of frequency scales, and ``N`` is the number of time
samples.

:return tdat:
shape ``(N,)`` array of real, time domain data

If the forward transform was performed in a different basis, then this
function will give incorrect output!

If you plan on doing several ICWTs in the same basis you should seriously
consider initializing a ``WaveletBasis`` object and using:
``WaveletBasis.cwt()`` and ``WaveletBasis.icwt()``.
"""
WB = WaveletBasis(N=wdat.shape[1], dt=dt)
return WB.icwt(wdat)


def cwtfreq(N, dt=1):
"""Output the Fourier frequencies of the scales used in the default
``WaveletBasis``.

:param N:
number of time samples in the time domain data.
:param dt:
sample cadence of data

:return freqs:
shape ``(M,)`` array of frequencies
"""
WB = WaveletBasis(N=N, dt=dt)
return WB.freqs
Loading