Advanced Computing Platform for Theoretical Physics

Commit 9141238a authored by Maximilian Soelch's avatar Maximilian Soelch
Browse files

Merge branch 'iter_minibatches-iterator' of https://github.com/bachard/climin...

Merge branch 'iter_minibatches-iterator' of https://github.com/bachard/climin into bachard-iter_minibatches-iterator
parents 1b6d4c1f 9dee2460
......@@ -322,8 +322,8 @@ def iter_minibatches(lst, batch_size, dims, n_cycles=False, random_state=None):
----------
lst : list of array_like
Each item of the list will be sliced into mini batches in alignemnt with
the others.
Each item of the list will be sliced into mini batches in alignment
with the others.
batch_size : int
Size of each batch. Last batch might be smaller.
......@@ -347,19 +347,47 @@ def iter_minibatches(lst, batch_size, dims, n_cycles=False, random_state=None):
Infinite iterator of mini batches in random order (without
replacement).
"""
batches = [minibatches(i, batch_size, d) for i, d in zip(lst, dims)]
if len(batches) > 1:
if any(len(i) != len(batches[0]) for i in batches[1:]):
raise ValueError("containers to be batched have different lengths")
if any([d > 2 for d in dims]):
raise ValueError("cannot slice along a dimension larger than 2")
if any([len(arr.shape) <= d for (arr, d) in zip(lst, dims)]):
raise ValueError("cannot slice along that dimension (shape, dim): {}"
.format([(a.shape, d) for (a, d) in zip(lst, dims)]))
if any([arr.shape[d] != lst[0].shape[dims[0]]
for (arr, d) in zip(lst, dims)]):
raise ValueError("containers to be batched have different lengths: {}"
.format([a.shape[d] for (a, d) in zip(lst, dims)]))
# This alternative is to make this work with lists in the case of d == 0.
if dims[0] == 0:
n_batches, rest = divmod(len(lst[0]), batch_size)
else:
n_batches, rest = divmod(lst[0].shape[dims[0]], batch_size)
if rest:
n_batches += 1
counter = itertools.count()
if random_state is not None:
random.seed(random_state.normal())
while True:
indices = [i for i, _ in enumerate(batches[0])]
indices = range(n_batches)
while True:
random.shuffle(indices)
for i in indices:
yield tuple(b[i] for b in batches)
start = i * batch_size
end = (i + 1) * batch_size
batch = []
for (arr, d) in zip(lst, dims):
if d == 0:
batch.append(arr[start:end])
elif d == 1:
batch.append(arr[:, start:end])
elif d == 2:
batch.append(arr[:, :, start:end])
yield tuple(batch)
count = next(counter)
if n_cycles and count >= n_cycles:
raise StopIteration()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment