Advanced Computing Platform for Theoretical Physics

Commit d2817146 authored by Lei Wang's avatar Lei Wang
Browse files

added inference code

parent bf7ac037
......@@ -26,4 +26,3 @@ save_path += 'kejax' \
+ '_k' + str(args.kernelsize) \
+ '_b' + str(args.batchsize) \
+ '_lr' + str(args.lr)
print ('Save path:', save_path)
import jax
from jax.config import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
import io, time
import matplotlib.pyplot as plt
from data import make_training_data
from model import make_network
from loss import make_loss
from args import args
import checkpoint
ckpt_restore_path = checkpoint.get_restore_path(args.restore_path)
ckpt_restore_filename = checkpoint.find_last_checkpoint(ckpt_restore_path)
xmesh, get_batch = make_training_data(args.N, args.G)
network = make_network(args.depth, args.channels, args.kernelsize)
key = jax.random.PRNGKey(42)
params = network.init(key, xmesh)
if ckpt_restore_filename:
t_init, params = checkpoint.restore(ckpt_restore_filename, args.batchsize)
else:
print('No checkpoint found')
vs, rhos, Es = get_batch(key, args.batchsize)
ts = jax.vmap(network.apply, (None, None, 0))(params, None, rhos)
print(ts.shape, jnp.sum(rhos* ts, axis=1).shape)
x0 = jnp.log(jnp.mean(rhos/args.N, axis=0))
loss_fn, predict_fn = make_loss(network, x0, args.N)
predict_rhos = predict_fn(params, vs)
plt.subplot(121)
plt.plot(xmesh, predict_rhos[0, :], label='predicted')
plt.plot(xmesh, rhos[0, :], label='true')
plt.subplot(122)
t = lambda rho: jnp.sum(network.apply(params, None, rho) *rho)
plt.plot(xmesh, jax.grad(t)(rhos[0, :]), label='predicted')
plt.plot(xmesh, -vs[0, :], label='-v')
plt.legend()
plt.show()
......@@ -12,6 +12,7 @@ from loss import make_loss
from args import args, save_path
import checkpoint
print ('Save path:', save_path)
ckpt_save_path = checkpoint.create_save_path(save_path)
ckpt_restore_path = checkpoint.get_restore_path(args.restore_path)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment