Advanced Computing Platform for Theoretical Physics

Commit da689651 authored by Lei Wang's avatar Lei Wang
Browse files

write loss to log

parent c5bd9aec
......@@ -15,6 +15,7 @@ import jax.scipy.optimize
from jax.experimental import optimizers
from tqdm import trange
import io
import matplotlib.pyplot as plt
import haiku as hk
......@@ -159,25 +160,16 @@ def step(i, key, opt_state):
opt_state = opt_update(i, grad, opt_state)
return value, opt_state
plt.ion()
loss_history = []
with trange(10000) as t:
for epoch in t:
key, subkey = jax.random.split(key)
value, opt_state = step(epoch, subkey, opt_state)
t.set_postfix(loss=value)
loss_history.append(value)
if (epoch % 10==0):
plt.cla()
plt.plot(loss_history)
plt.xscale('log')
plt.yscale('log')
plt.draw()
plt.pause(0.01)
plt.ioff()
with io.open('loss.log', 'a', buffering=1, newline='\n') as logfile:
with trange(10000) as t:
for epoch in t:
key, subkey = jax.random.split(key)
value, opt_state = step(epoch, subkey, opt_state)
t.set_postfix(loss=value)
message = ('{} ' + 1*'{:.5f} ').format(epoch, value)
logfile.write(message + u'\n')
def test():
predict_rhos = predict_fn(get_params(opt_state), vs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment