Advanced Computing Platform for Theoretical Physics

Commit 9dee2460 authored by Bastien Achard
Browse files

produces an iterator without calling minibatches function, thus not creating...

produces an iterator without calling minibatches function, thus not creating an array containing all the dataset
parent bc8e436e
......@@ -322,8 +322,8 @@ def iter_minibatches(lst, batch_size, dims, n_cycles=False, random_state=None):
----------
lst : list of array_like
Each item of the list will be sliced into mini batches in alignemnt with
the others.
Each item of the list will be sliced into mini batches in alignment
with the others.
batch_size : int
Size of each batch. Last batch might be smaller.
......@@ -347,19 +347,47 @@ def iter_minibatches(lst, batch_size, dims, n_cycles=False, random_state=None):
Infinite iterator of mini batches in random order (without
replacement).
"""
batches = [minibatches(i, batch_size, d) for i, d in zip(lst, dims)]
if len(batches) > 1:
if any(len(i) != len(batches[0]) for i in batches[1:]):
raise ValueError("containers to be batched have different lengths")
if any([d > 2 for d in dims]):
raise ValueError("cannot slice along a dimension larger than 2")
if any([len(arr.shape) <= d for (arr, d) in zip(lst, dims)]):
raise ValueError("cannot slice along that dimension (shape, dim): {}"
.format([(a.shape, d) for (a, d) in zip(lst, dims)]))
if any([arr.shape[d] != lst[0].shape[dims[0]]
for (arr, d) in zip(lst, dims)]):
raise ValueError("containers to be batched have different lengths: {}"
.format([a.shape[d] for (a, d) in zip(lst, dims)]))
# This alternative is to make this work with lists in the case of d == 0.
if dims[0] == 0:
n_batches, rest = divmod(len(lst[0]), batch_size)
else:
n_batches, rest = divmod(lst[0].shape[dims[0]], batch_size)
if rest:
n_batches += 1
counter = itertools.count()
if random_state is not None:
random.seed(random_state.normal())
while True:
indices = [i for i, _ in enumerate(batches[0])]
indices = range(n_batches)
while True:
random.shuffle(indices)
for i in indices:
yield tuple(b[i] for b in batches)
start = i * batch_size
end = (i + 1) * batch_size
batch = []
for (arr, d) in zip(lst, dims):
if d == 0:
batch.append(arr[start:end])
elif d == 1:
batch.append(arr[:, start:end])
elif d == 2:
batch.append(arr[:, :, start:end])
yield tuple(batch)
count = next(counter)
if n_cycles and count >= n_cycles:
raise StopIteration()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment