Source code for pumpp.sampler

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
Data subsampling
================
.. autosummary::
    :toctree: generated/

    Sampler
'''

from itertools import count

import numpy as np

# Public API of this module: only `Sampler` is exported by a star-import.
__all__ = ['Sampler']


class Sampler(object):
    '''Generate samples from a pumpp data dict.

    Parameters
    ----------
    n_samples : int or None
        the number of samples to generate.
        If `None` (or any other falsy value, e.g. `0`), generate indefinitely.

    duration : int > 0
        the duration (in frames) of each sample

    ops : one or more pumpp.feature.FeatureExtractor or pumpp.task.BaseTaskTransformer
        The operators to include when sampling data.
        Each must expose a `fields` dict whose values have a `shape`
        attribute; a `None` entry in `shape` marks the time-like axis.

    Attributes
    ----------
    n_samples : int or None
        As given to the constructor.

    duration : int > 0
        As given to the constructor.

    Examples
    --------
    >>> # Set up the parameters
    >>> sr, n_fft, hop_length = 22050, 512, 2048
    >>> # Instantiate some transformers
    >>> p_stft = pumpp.feature.STFTMag('stft', sr=sr, n_fft=n_fft,
    ...                                hop_length=hop_length)
    >>> p_beat = pumpp.task.BeatTransformer('beat', sr=sr,
    ...                                     hop_length=hop_length)
    >>> # Apply the transformers to the data
    >>> data = pumpp.transform('test.ogg', 'test.jams', p_stft, p_beat)
    >>> # We'll sample 10 patches of duration = 32 frames
    >>> stream = pumpp.Sampler(10, 32, p_stft, p_beat)
    >>> # Apply the streamer to the data dict
    >>> for example in stream(data):
    ...     process(example)
    '''

    def __init__(self, n_samples, duration, *ops):
        '''Record the sampling parameters and locate each field's time axis.

        Parameters
        ----------
        n_samples : int or None
            Number of samples to generate; falsy means unlimited.

        duration : int > 0
            Length (in frames) of each sampled patch.

        ops : objects with a `fields` dict
            The operators whose fields are to be sampled.
        '''
        self.n_samples = n_samples
        self.duration = duration

        # Merge the field specifications of all operators.
        # On a key collision, the later operator wins.
        fields = dict()
        for op in ops:
            fields.update(op.fields)

        # Pre-determine which fields have time-like indices.
        # `None` in a field's shape marks its (first) time-like axis;
        # fields without one map to `None`.
        self._time = dict()
        for key, field in fields.items():
            shape = field.shape
            if None in shape:
                # Add one for the batching index
                self._time[key] = 1 + shape.index(None)
            else:
                self._time[key] = None

    def sample(self, data, interval):
        '''Sample a patch from the data object

        Parameters
        ----------
        data : dict
            A data dict as produced by pumpp.transform

        interval : slice
            The time interval to sample

        Returns
        -------
        data_slice : dict
            `data` restricted to `interval`.
            Keys containing `'_valid'` are omitted.  For every other key,
            one observation along the leading axis is drawn at random
            (independently per key) and kept as a length-1 axis.
        '''
        data_slice = dict()

        for key in data:
            if '_valid' in key:
                continue

            values = data[key]
            index = [slice(None)] * values.ndim

            # if we have multiple observations for this key, pick one;
            # slicing (rather than integer indexing) keeps the leading axis
            obs = np.random.randint(0, values.shape[0])
            index[0] = slice(obs, obs + 1)

            time_axis = self._time.get(key)
            if time_axis is not None:
                index[time_axis] = interval

            # NumPy requires a tuple for multidimensional indexing: a list
            # of slices is interpreted as a fancy index and fails.
            data_slice[key] = values[tuple(index)]

        return data_slice

    def data_duration(self, data):
        '''Compute the valid data duration of a dict

        Parameters
        ----------
        data : dict
            As produced by pumpp.transform

        Returns
        -------
        length : int
            The minimum temporal extent of a dynamic observation in data

        Raises
        ------
        ValueError
            If none of the sampler's fields has a time-like axis.
        '''
        # Find the extent of every time-like axis known to the sampler
        lengths = [data[key].shape[axis]
                   for key, axis in self._time.items()
                   if axis is not None]

        if not lengths:
            raise ValueError('Cannot compute data duration: '
                             'no fields with a time-like axis')

        return min(lengths)

    def __call__(self, data):
        '''Generate samples from a data dict.

        Parameters
        ----------
        data : dict
            As produced by pumpp.transform

        Yields
        ------
        data_sample : dict
            A sequence of patch samples from `data`,
            as parameterized by the sampler object.

        Raises
        ------
        ValueError
            If `data` is shorter than the sampler's `duration`.
            As this is a generator, the error surfaces on first iteration.
        '''
        duration = self.data_duration(data)

        # Largest admissible start frame for a full-length patch
        max_start = duration - self.duration
        if max_start < 0:
            raise ValueError('Data duration={} is shorter than the '
                             'requested sample duration={}'
                             .format(duration, self.duration))

        for i in count(0):
            # are we done?  A falsy n_samples means "never".
            if self.n_samples and i >= self.n_samples:
                break

            # Generate a sampling interval.  randint's upper bound is
            # exclusive, hence the +1 so that max_start itself can be drawn.
            start = np.random.randint(0, max_start + 1)

            yield self.sample(data, slice(start, start + self.duration))