#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
Data subsampling
================
.. autosummary::
    :toctree: generated/

    Sampler
    SequentialSampler
    VariableLengthSampler
'''
from itertools import count
import six
import numpy as np
from .base import Slicer
from .exceptions import ParameterError, DataError
# Names exported by a star-import of this module: the three sampler classes.
__all__ = ['Sampler', 'SequentialSampler', 'VariableLengthSampler']
class Sampler(Slicer):
    '''Generate samples uniformly at random from a pumpp data dict.

    Attributes
    ----------
    n_samples : int or None
        the number of samples to generate.
        If `None`, generate indefinitely.
        If `0`, no samples are generated.

    duration : int > 0
        the duration (in frames) of each sample

    random_state : None, int, or np.random.RandomState
        If int, random_state is the seed used by the random number
        generator;

        If RandomState instance, random_state is the random number
        generator;

        If None, the random number generator is the RandomState instance
        used by np.random.

    ops : array of pumpp.feature.FeatureExtractor or pumpp.task.BaseTaskTransformer
        The operators to include when sampling data.

    Examples
    --------
    >>> # Set up the parameters
    >>> sr, n_fft, hop_length = 22050, 512, 2048
    >>> # Instantiate some transformers
    >>> p_stft = pumpp.feature.STFTMag('stft', sr=sr, n_fft=n_fft,
    ...                                hop_length=hop_length)
    >>> p_beat = pumpp.task.BeatTransformer('beat', sr=sr,
    ...                                     hop_length=hop_length)
    >>> # Apply the transformers to the data
    >>> data = pumpp.transform('test.ogg', 'test.jams', p_stft, p_beat)
    >>> # We'll sample 10 patches of duration = 32 frames
    >>> stream = pumpp.Sampler(10, 32, p_stft, p_beat)
    >>> # Apply the streamer to the data dict
    >>> for example in stream(data):
    ...     process(example)
    '''

    def __init__(self, n_samples, duration, *ops, **kwargs):
        '''Set up the sampler.

        Parameters
        ----------
        n_samples : int or None
            Number of samples to generate; `None` for an endless stream.

        duration : int > 0
            Duration (in frames) of each sample.

        *ops
            Operators forwarded to the `Slicer` constructor.

        **kwargs
            Only `random_state` is consumed here (see the class docstring).

        Raises
        ------
        ParameterError
            If `random_state` is not `None`, an integer, or a
            `np.random.RandomState` instance.
        '''
        super(Sampler, self).__init__(*ops)

        self.n_samples = n_samples
        self.duration = duration

        random_state = kwargs.pop('random_state', None)

        if random_state is None:
            # Fall back on numpy's global random state
            self.rng = np.random
        elif isinstance(random_state, six.integer_types + (np.integer,)):
            # Accept numpy integer scalars as seeds, not just built-in ints
            self.rng = np.random.RandomState(seed=random_state)
        elif isinstance(random_state, np.random.RandomState):
            self.rng = random_state
        else:
            raise ParameterError('Invalid random_state={}'.format(random_state))

    def sample(self, data, interval):
        '''Sample a patch from the data object

        Parameters
        ----------
        data : dict
            A data dict as produced by pumpp.Pump.transform

        interval : slice
            The time interval to sample

        Returns
        -------
        data_slice : dict
            `data` restricted to `interval`.
            Keys containing `'_valid'` are omitted.
        '''
        data_slice = dict()

        for key in data:
            if '_valid' in key:
                continue

            index = [slice(None)] * data[key].ndim

            # if we have multiple observations for this key, pick one;
            # a length-1 slice (rather than an int) keeps the leading axis
            obs = self.rng.randint(0, data[key].shape[0])
            index[0] = slice(obs, obs + 1)

            # `_time` is provided by the Slicer base class; it is indexed
            # by key and iterated for the axes to restrict to `interval`
            for tdim in self._time[key]:
                index[tdim] = interval

            data_slice[key] = data[key][tuple(index)]

        return data_slice

    def indices(self, data):
        '''Generate patch indices

        Parameters
        ----------
        data : dict of np.ndarray
            As produced by pumpp.transform

        Yields
        ------
        start : int >= 0
            The start index of a sample patch

        Raises
        ------
        DataError
            If the data duration is less than the sample duration.
        '''
        duration = self.data_duration(data)

        if self.duration > duration:
            raise DataError('Data duration={} is less than '
                            'sample duration={}'.format(duration, self.duration))

        while True:
            # Generate a sampling interval.  randint's upper bound is
            # exclusive, so the last admissible start is
            # `duration - self.duration`.
            yield self.rng.randint(0, duration - self.duration + 1)

    def __call__(self, data):
        '''Generate samples from a data dict.

        Parameters
        ----------
        data : dict
            As produced by pumpp.transform

        Yields
        ------
        data_sample : dict
            A sequence of patch samples from `data`,
            as parameterized by the sampler object.
        '''
        # Compare against None explicitly: `n_samples=0` must produce an
        # empty stream, not an endless one.
        if self.n_samples is not None:
            counter = six.moves.range(self.n_samples)
        else:
            counter = count(0)

        # The counter comes first in the zip so that, once it is exhausted,
        # no further index is drawn from the random number generator.
        for _, start in six.moves.zip(counter, self.indices(data)):
            yield self.sample(data, slice(start, start + self.duration))
class SequentialSampler(Sampler):
    '''Sample patches in sequential (temporal) order

    Attributes
    ----------
    duration : int > 0
        the duration (in frames) of each sample

    stride : int > 0
        The number of frames to advance between samples.
        By default, matches `duration` so there is no overlap.

    ops : array of pumpp.feature.FeatureExtractor or pumpp.task.BaseTaskTransformer
        The operators to include when sampling data.

    random_state : None, int, or np.random.RandomState
        If int, random_state is the seed used by the random number
        generator;

        If RandomState instance, random_state is the random number
        generator;

        If None, the random number generator is the RandomState instance
        used by np.random.

    See Also
    --------
    Sampler
    '''

    def __init__(self, duration, *ops, **kwargs):
        '''Set up the sequential sampler.

        Parameters
        ----------
        duration : int > 0
            Duration (in frames) of each sample.

        *ops
            Operators forwarded to the `Sampler` constructor.

        **kwargs
            `stride` is consumed here; everything else (e.g.
            `random_state`) is forwarded to the `Sampler` constructor.

        Raises
        ------
        ParameterError
            If `stride` is not positive.
        '''
        stride = kwargs.pop('stride', None)

        # n_samples=None: the stream ends when the start indices run out
        super(SequentialSampler, self).__init__(None, duration, *ops, **kwargs)

        if stride is None:
            stride = duration

        if not stride > 0:
            raise ParameterError('Invalid patch stride={}'.format(stride))

        self.stride = stride

    def indices(self, data):
        '''Generate patch start indices

        Parameters
        ----------
        data : dict of np.ndarray
            As produced by pumpp.transform

        Yields
        ------
        start : int >= 0
            The start index of a sample patch.
            Nothing is yielded if the data is shorter than one patch.
        '''
        duration = self.data_duration(data)

        # The last full-length patch starts at `duration - self.duration`.
        # range() excludes its stop value, hence the `+ 1`; without it the
        # final patch is dropped, and data exactly one patch long yields
        # nothing.
        for start in range(0, duration - self.duration + 1, self.stride):
            yield start
class VariableLengthSampler(Sampler):
    '''Sample random patches like a `Sampler`, but allow for
    output patches to be less than the target duration when
    fewer than `max_duration` frames remain after the sampled
    start position.

    Attributes
    ----------
    n_samples : int or None
        the number of samples to generate.
        If `None`, generate indefinitely.
        If `0`, no samples are generated.

    min_duration : int > 0
        The minimum duration (in frames) of each sample

    max_duration : int > 0
        the maximum duration (in frames) of each sample

    random_state : None, int, or np.random.RandomState
        If int, random_state is the seed used by the random number
        generator;

        If RandomState instance, random_state is the random number
        generator;

        If None, the random number generator is the RandomState instance
        used by np.random.

    ops : array of pumpp.feature.FeatureExtractor or pumpp.task.BaseTaskTransformer
        The operators to include when sampling data.

    See Also
    --------
    Sampler
    '''

    def __init__(self, n_samples, min_duration, max_duration, *ops, **kwargs):
        '''Set up the variable-length sampler.

        Parameters
        ----------
        n_samples : int or None
            Number of samples to generate; `None` for an endless stream.

        min_duration : int > 0
            Minimum duration (in frames) of each sample.

        max_duration : int >= min_duration
            Maximum duration (in frames) of each sample.
            Stored as `self.duration` by the `Sampler` constructor.

        *ops, **kwargs
            Forwarded to the `Sampler` constructor.

        Raises
        ------
        ParameterError
            If `min_duration < 1` or `max_duration < min_duration`.
        '''
        super(VariableLengthSampler, self).__init__(n_samples, max_duration,
                                                    *ops, **kwargs)

        if min_duration < 1:
            raise ParameterError('min_duration={} must be '
                                 'at least 1.'.format(min_duration))

        if max_duration < min_duration:
            raise ParameterError('max_duration={} must be at least '
                                 'min_duration={}'.format(max_duration,
                                                          min_duration))

        self.min_duration = min_duration

    def indices(self, data):
        '''Generate patch indices

        Parameters
        ----------
        data : dict of np.ndarray
            As produced by pumpp.transform

        Yields
        ------
        start : int >= 0
            The start index of a sample patch

        Raises
        ------
        DataError
            If the data duration is less than `min_duration`.
        '''
        duration = self.data_duration(data)

        # Fail with a descriptive error instead of letting randint reject
        # an empty range when not even a minimum-length patch fits.
        if self.min_duration > duration:
            raise DataError('Data duration={} is less than '
                            'minimum sample duration={}'.format(
                                duration, self.min_duration))

        while True:
            # Generate a sampling interval.  Every start leaves at least
            # `min_duration` frames before the end of the data.
            yield self.rng.randint(0, duration - self.min_duration + 1)

    def __call__(self, data):
        '''Generate samples from a data dict.

        Parameters
        ----------
        data : dict
            As produced by pumpp.transform

        Yields
        ------
        data_sample : dict
            A sequence of patch samples from `data`,
            as parameterized by the sampler object.
            Patches are truncated at the end of the data.
        '''
        # Compare against None explicitly: `n_samples=0` must produce an
        # empty stream, not an endless one.
        if self.n_samples is not None:
            counter = six.moves.range(self.n_samples)
        else:
            counter = count(0)

        duration = self.data_duration(data)

        for _, start in six.moves.zip(counter, self.indices(data)):
            # Clip the patch end to the data boundary; self.duration holds
            # max_duration (set by the parent constructor).
            stop = min(duration, start + self.duration)
            yield self.sample(data, slice(start, stop))