Advanced Computing Platform for Theoretical Physics

Commit 27a3c051 authored by Justin Bayer's avatar Justin Bayer
Browse files

added momentum to rmsprop

parent 8f4d5127
......@@ -8,16 +8,18 @@ from base import Minimizer
class RmsProp(Minimizer):
def __init__(self, wrt, fprime, steprate, decay,
def __init__(self, wrt, fprime, steprate, decay, momentum=0,
args=None, logfunc=None):
super(RmsProp, self).__init__(wrt, args=args, logfunc=logfunc)
self.fprime = fprime
self.steprate = steprate
self.decay = decay
self.momentum = momentum
def __iter__(self):
moving_mean_squared = None
step_m1 = 0
for i, (args, kwargs) in enumerate(self.args):
gradient = self.fprime(self.wrt, *args, **kwargs)
......@@ -27,7 +29,9 @@ class RmsProp(Minimizer):
self.decay * moving_mean_squared
+ (1 - self.decay) * gradient**2)
step = self.steprate * gradient / np.sqrt(moving_mean_squared + 1e-8)
step += step_m1 * self.momentum
self.wrt -= step
step_m1 = step
yield dict(args=args, kwargs=kwargs, gradient=gradient,
n_iter=i,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment