Source code for tvb.analyzers.wavelet

# -*- coding: utf-8 -*-
#
#
# TheVirtualBrain-Scientific Package. This package holds all simulators, and
# analysers necessary to run brain-simulations. You can use it stand alone or
# in conjunction with TheVirtualBrain-Framework Package. See content of the
# documentation-folder for more details. See also http://www.thevirtualbrain.org
#
# (c) 2012-2023, Baycrest Centre for Geriatric Care ("Baycrest") and others
#
# This program is free software: you can redistribute it and/or modify it under the
# terms of the GNU General Public License as published by the Free Software Foundation,
# either version 3 of the License, or (at your option) any later version.
# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE.  See the GNU General Public License for more details.
# You should have received a copy of the GNU General Public License along with this
# program.  If not, see <http://www.gnu.org/licenses/>.
#
#
#   CITATION:
# When using The Virtual Brain for scientific publications, please cite it as explained here:
# https://www.thevirtualbrain.org/tvb/zwei/neuroscience-publications
#
#

"""
Calculate a wavelet transform on a TimeSeries datatype and return a
WaveletSpectrum datatype.

.. moduleauthor:: Stuart A. Knock <Stuart@tvb.invalid>
.. moduleauthor:: Andreas Spiegler <anspiegler@googlemail.com>
.. moduleauthor:: Marmaduke Woodman <marmaduke.woodman@univ-amu.fr>
.. moduleauthor:: Paula Sanz Leon <Paula@tvb.invalid>

"""

import numpy
import scipy.signal as signal
import tvb.datatypes.spectral as spectral
from tvb.basic.logger.builder import get_logger
from tvb.basic.neotraits.api import HasTraits, Attr, Range, Float, narray_describe
from tvb.simulator.backend.ref import ReferenceBackend

SUPPORTED_WAVELET_FUNCTIONS = ("morlet",)

log = get_logger(__name__)

"""
A module for calculating the wavelet transform of a TimeSeries object of TVB
and returning a WaveletSpectrum object. The sampling period and frequency
range of the result can be specified. The mother wavelet can also be
specified... (So far, only Morlet.)

References:
    .. [TBetal_1996] C. Tallon-Baudry et al, *Stimulus Specificity of
        Phase-Locked and Non-Phase-Locked 40 Hz Visual Responses in Human.*,
        J Neurosci 16(13):4240-4249, 1996.

    .. [Mallat_1999] S. Mallat, *A wavelet tour of signal processing.*,
        book, Academic Press, 1999.
"""


def compute_continuous_wavelet_transform(time_series, frequencies, sample_period, q_ratio, normalisation, mother):
    """
    # type: (TimeSeries, Range, float, float, str, str)  -> WaveletCoefficients
    Calculate the continuous wavelet transform of time_series.

    Every (state-variable, node, mode) trace of ``time_series.data`` is
    convolved with one complex, Gaussian-windowed exponential wavelet per
    requested frequency, and the result is sub-sampled in time.

    Parameters
    ----------
    time_series : TimeSeries
        The timeseries to which the wavelet is to be applied. Its ``data``
        is indexed here as (time, state-variable, node, mode).
    frequencies : Range
        The frequency resolution and range returned. Requested frequencies
        are expected to be in kHz. If ``step`` is 0 it is overwritten *in
        place* with 0.002; if the resulting range is empty or contains
        non-positive values, a default ``Range(lo=0.008, hi=0.060,
        step=0.002)`` is used instead.
    sample_period : float
        The sampling period in ms of the computed wavelet spectrum.
    q_ratio : float or numpy.ndarray
        NFC. Ratios of the center frequencies to bandwidths. Must not be
        lower than 5. A scalar is applied to every frequency; an array must
        be broadcastable to shape ``(1, number_of_frequencies)``.
    normalisation : str
        The type of normalisation for the resulting wavelet spectrum.
        Options are: 'energy'; 'gabor'.
    mother : str
        The mother wavelet function used in the transform. It is passed
        through to the result unchanged and does not alter the computation
        performed in this function.

    Returns
    -------
    WaveletCoefficients
        Complex coefficients with shape
        (frequency, time, state-variable, node, mode).

    Raises
    ------
    Exception
        If ``q_ratio`` is lower than 5, or if the highest requested
        frequency exceeds half the sampling rate of ``time_series``.
    ValueError
        If ``normalisation`` is neither 'energy' nor 'gabor', or if an
        array ``q_ratio`` cannot be broadcast to one value per frequency.
    """
    ts_shape = time_series.data.shape

    if frequencies.step == 0:
        log.warning("Frequency step can't be 0! Trying default step, 2e-3.")
        # NOTE: this writes to the caller's Range object (kept for compatibility).
        frequencies.step = 0.002

    freqs = numpy.arange(frequencies.lo, frequencies.hi, frequencies.step)

    if (freqs.size == 0) or numpy.any(freqs <= 0.0):
        # TODO: Maybe should limit number of freqs... ~100 is probably a reasonable upper bound.
        log.warning("Invalid frequency range! Falling back to default.")
        log.debug("freqs")
        log.debug(narray_describe(freqs))
        frequencies = Range(lo=0.008, hi=0.060, step=0.002)
        freqs = numpy.arange(frequencies.lo, frequencies.hi, frequencies.step)

    log.debug("freqs")
    log.debug(narray_describe(freqs))

    # We need this to be kHz (see TVB-2946)
    sample_rate = time_series.sample_rate / 1000

    # Duke: code below is as given by Andreas Spiegler, I've just wrapped
    # some of the original argument names
    nf = len(freqs)
    # Keep every `temporal_step`-th sample of the convolved signal (at least 1).
    temporal_step = max((1, ReferenceBackend.iround(sample_period / time_series.sample_period_ms)))
    nt = int(numpy.ceil(ts_shape[0] / temporal_step))

    # One q-ratio per frequency, laid out as a (1, nf) row.
    if isinstance(q_ratio, numpy.ndarray):
        # Previously an array q_ratio left `new_q_ratio` unassigned and crashed below.
        new_q_ratio = numpy.broadcast_to(q_ratio, (1, nf))
    else:
        new_q_ratio = q_ratio * numpy.ones((1, nf))

    if numpy.nanmin(new_q_ratio) < 5:
        msg = "q_ratio must be not lower than 5 !"
        log.error(msg)
        raise Exception(msg)

    if numpy.nanmax(freqs) > sample_rate / 2.0:
        msg = "Sampling rate is too low for the requested frequency range !"
        log.error(msg)
        raise Exception(msg)

    # TODO: This isn't used, but min frequency seems like it should be important... Check with A.S.
    # fmin = 3.0 * numpy.nanmin(q_ratio) * sample_rate / numpy.pi / nt
    sigma_f = freqs / new_q_ratio
    sigma_t = 1.0 / (2.0 * numpy.pi * sigma_f)

    if normalisation == 'energy':
        amplitude = 1.0 / numpy.sqrt(sample_rate * numpy.sqrt(numpy.pi) * sigma_t)
    elif normalisation == 'gabor':
        amplitude = numpy.sqrt(2.0 / numpy.pi) / sample_rate / sigma_t
    else:
        # Fail with a clear message instead of an unbound-variable error later on.
        msg = "Unsupported normalisation %r, expected 'energy' or 'gabor' !" % (normalisation,)
        log.error(msg)
        raise ValueError(msg)

    coef_shape = (nf, nt, ts_shape[1], ts_shape[2], ts_shape[3])
    coef = numpy.zeros(coef_shape, dtype=numpy.complex128)
    log.debug("coef")
    log.debug(narray_describe(coef))

    for i, f0 in enumerate(freqs):
        sd_t = sigma_t[0, i]
        amp = amplitude[0, i]
        # One-sided support out to 4 standard deviations of the Gaussian window.
        x = numpy.arange(0, 4.0 * sd_t * sample_rate, 1) / sample_rate
        wvlt = amp * numpy.exp(-x ** 2 / (2.0 * sd_t ** 2)) * numpy.exp(2j * numpy.pi * f0 * x)
        # Mirror with the complex conjugate (skipping x == 0) to get the two-sided wavelet.
        wvlt = numpy.hstack((numpy.conjugate(wvlt[-1:0:-1]), wvlt))
        for var in range(ts_shape[1]):
            for node in range(ts_shape[2]):
                for mode in range(ts_shape[3]):
                    data = time_series.data[:, var, node, mode]
                    wt = signal.convolve(data, wvlt, 'same')
                    res = wt[0::temporal_step]
                    # NOTE: truncating to nt guards against broadcasting errors
                    # when using dt and sample periods which are not powers of 2.
                    coef[i, :, var, node, mode] = res[:nt]

    log.debug("coef")
    log.debug(narray_describe(coef))

    spectra = spectral.WaveletCoefficients(
        source=time_series,
        mother=mother,
        sample_period=sample_period,
        frequencies=frequencies.to_array(),
        normalisation=normalisation,
        q_ratio=q_ratio,
        array_data=coef)

    return spectra