Advanced Computing Platform for Theoretical Physics

Commit 005c9239 authored by Maximilian Soelch's avatar Maximilian Soelch
Browse files

Merge branch 'bachard-iter_minibatches-iterator'

parents 71824107 262f389b
......@@ -261,6 +261,55 @@ def empty_with_views(shapes, empty_func=np.empty):
return flat, views
def arbitrary_slice(arr, start, stop=None, axis=0):
"""Return a slice from `start` to `stop` in dimension `axis` of array `arr`.
Parameters
----------
arr : array_like
Can be numpy ndarray, hdf5 dataset, or list.
If ``arr`` is a list, ``axis`` must be 0.
start : int
Index at which to start slicing.
stop : int, optional [default: None]
Index at which to stop slicing.
If not specified, the axis is sliced until its end.
axis : int, optional [default: 0]
Axis along which should be sliced
Returns
-------
slice : array_like
The respective slice of ``arr``
"""
if type(arr) is list:
if axis == 0:
return arr[start:stop]
else:
raise ValueError("Cannot slice a list in non-zero axis {}"
.format(axis))
n_axes = len(arr.shape)
if axis >= n_axes:
raise IndexError('Argument `axis` with value {} out of range. '
'Must be smaller than rank {} of `arr`.'
.format(axis, n_axes))
this_slice = [slice(None) for _ in range(n_axes)]
this_slice[axis] = slice(start, stop)
return arr[tuple(this_slice)]
def minibatches(arr, batch_size, d=0):
"""Return a list of views of the given arr.
......@@ -308,11 +357,11 @@ def minibatches(arr, batch_size, d=0):
return res
def iter_minibatches(lst, batch_size, dims, n_cycles=False, random_state=None):
def iter_minibatches(lst, batch_size, dims, n_cycles=None, random_state=None,
discard_illsized_batch=False):
"""Return an iterator that successively yields tuples containing aligned
minibatches of size `batch_size` from slicable objects given in `lst`, in
random order without replacement.
Because different containers might require slicing over different
dimensions, the dimension of each container has to be givens as a list
`dims`.
......@@ -322,8 +371,8 @@ def iter_minibatches(lst, batch_size, dims, n_cycles=False, random_state=None):
----------
lst : list of array_like
Each item of the list will be sliced into mini batches in alignemnt with
the others.
Each item of the list will be sliced into mini batches in alignment
with the others.
batch_size : int
Size of each batch. Last batch might be smaller.
......@@ -332,37 +381,64 @@ def iter_minibatches(lst, batch_size, dims, n_cycles=False, random_state=None):
Aligned with ``lst``, gives the dimension along which the data samples
are separated.
n_cycles : int or False, optional [default: False]
Number of cycles after which to stop the iterator. If ``False``, will
n_cycles : int, optional [default: None]
Number of cycles after which to stop the iterator. If ``None``, will
yield forever.
random_state : a numpy.random.RandomState object, optional [default : None]
Random number generator that will act as a seed for the minibatch order
Random number generator that will act as a seed for the minibatch order.
discard_illsized_batch : bool, optional [default : False]
If ``True`` and the length of the sliced dimension is not divisible by
``batch_size``, the leftover samples are discarded.
Returns
-------
batches : iterator
Infinite iterator of mini batches in random order (without
replacement).
Infinite iterator of mini batches in random order (without replacement).
"""
batches = [minibatches(i, batch_size, d) for i, d in zip(lst, dims)]
if len(batches) > 1:
if any(len(i) != len(batches[0]) for i in batches[1:]):
raise ValueError("containers to be batched have different lengths")
counter = itertools.count()
try:
# case distinction for handling lists
dm_result = [divmod(len(arr), batch_size)
if d == 0 else divmod(arr.shape[d], batch_size)
for (arr, d) in zip(lst, dims)]
except AttributeError:
raise AttributeError("'list' object has no attribute 'shape'. "
"Trying to slice a list in a non-zero axis.")
except IndexError:
raise IndexError("tuple index out of range. "
"Trying to slice along a non-existing dimension.")
# check if all to-be-sliced dimensions have the same length
if dm_result.count(dm_result[0]) == len(dm_result):
n_batches, rest = dm_result[0]
if rest and not discard_illsized_batch:
n_batches += 1
else:
raise ValueError("The axes along which to slice have unequal lengths.")
if random_state is not None:
random.seed(random_state.normal())
counter = itertools.count()
count = next(counter)
while True:
indices = [i for i, _ in enumerate(batches[0])]
indices = range(n_batches)
while True:
if n_cycles is not None and count >= n_cycles:
raise StopIteration()
count = next(counter)
random.shuffle(indices)
for i in indices:
yield tuple(b[i] for b in batches)
count = next(counter)
if n_cycles and count >= n_cycles:
raise StopIteration()
start = i * batch_size
stop = (i + 1) * batch_size
batch = [arbitrary_slice(arr, start, stop, axis) for (arr, axis)
in zip(lst, dims)]
yield tuple(batch)
class OptimizerDistribution(object):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment