Commit 93ecb95a authored by Lei Wang

initial version
# -*- coding: utf-8 -*-
"""ke-functional.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1eJ82YWR19B9xd9ZcXFb9x8lSzl_N0MtE
"""
import jax
from jax.config import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax.scipy.optimize
from jax.experimental import optimizers
from tqdm import trange
import matplotlib.pyplot as plt
import haiku as hk
import jaxopt
"""Set up a one dimensional space following https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.108.253002, with some functions to generate training data (potential, density, kinetic energy)
"""
xmin, xmax, N = -0.5, 0.5, 100
xmesh = jnp.linspace(xmin, xmax, N, endpoint=False)
batchsize = 1
def make_training_data(xmesh):
    def buildv(key):
        # Random potential: a sum of three Gaussian wells (cf. the PRL reference).
        # Each well gets its own subkey so the three wells actually differ.
        res = jnp.zeros_like(xmesh)
        for subkey in jax.random.split(key, 3):
            r = jax.random.uniform(subkey, (3,))
            a = 9*r[0] + 1
            b = 0.2*r[1] + 0.4
            c = 0.07*r[2] + 0.03
            res = res - a * jnp.exp(-(xmesh+0.5-b)**2/(2*c**2))
        return res

    def solver(v):
        # Exact diagonalization on the grid: second-order finite-difference Laplacian
        # for the kinetic operator, plus the potential on the diagonal.
        N = v.shape[0]
        h = xmesh[1] - xmesh[0]
        K = -0.5/h**2 * (jnp.diag(-2 * jnp.ones(N))
                         + jnp.diag(jnp.ones(N - 1), k=1)
                         + jnp.diag(jnp.ones(N - 1), k=-1))
        w, psi = jnp.linalg.eigh(K + jnp.diag(v))
        rho = psi[:, 0]**2                     # ground-state density, sum(rho) = 1
        return rho, w[0] - jnp.sum(rho * v)    # kinetic energy T = E_0 - <v>

    def get_batch(key, batchsize):
        subkeys = jax.random.split(key, batchsize)
        vs = jax.vmap(buildv)(subkeys)
        rhos, Es = jax.vmap(solver)(vs)
        return vs, rhos, Es

    return get_batch
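"""As a quick sanity check (an added sketch, not part of the original notebook), draw a small batch and verify that each ground-state density sums to one on the grid and that the extracted kinetic energies are positive."""
_check_batch = make_training_data(xmesh)
_vs, _rhos, _Es = _check_batch(jax.random.PRNGKey(0), 4)
print(_vs.shape, _rhos.shape, _Es.shape)   # (4, 100) (4, 100) (4,)
print(jnp.sum(_rhos, axis=1))              # each ~1.0: the eigenvectors are normalized
print(_Es > 0)                             # kinetic energies should all be positive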
"""Now, set the kinetic energy functional as a simple MLP which maps density to kinetic energy"""
#t[n]
def forward_fn(x):
    # Input: density on the grid, shape (N,). Output: local kinetic energy density t[n], shape (N,).
    x = jnp.expand_dims(x, axis=1)
    for _ in range(4):
        resblock = hk.Sequential([
            hk.Conv1D(output_channels=16, kernel_shape=5, padding='SAME'),
            jax.nn.softplus,
        ])
        x = x + resblock(x)
    return hk.Conv1D(output_channels=1, kernel_shape=5, padding='SAME')(x).squeeze(axis=1)
network = hk.transform(forward_fn)
key = jax.random.PRNGKey(42)
params = network.init(key, xmesh)
get_batch = make_training_data(xmesh)
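"""A note added for clarity (this formula is a restatement of the code, not part of the original notebook): the network output $t_\theta(\rho)_i$ plays the role of a local kinetic energy density, and the kinetic energy functional is the discrete sum

$$T_\theta[\rho] \;=\; \sum_i \rho_i \, t_\theta(\rho)_i ,$$

with the density normalized as $\sum_i \rho_i = 1$, matching the eigenvector normalization used in the solver (grid-spacing factors are absorbed into the sums)."""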
"""Let's try to do a fit to the kinetic energy"""
def energy_loss(params, rhos, Es):
    ts = jax.vmap(network.apply, (None, None, 0))(params, None, rhos)
    return jnp.mean((jnp.sum(rhos * ts, axis=1) - Es)**2)
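"""In the same notation (again a restatement of the code, added here), `energy_loss` is the batch mean-squared error between the model kinetic energy and the exact one extracted from the solver,

$$\mathcal{L}_E(\theta) \;=\; \frac{1}{B} \sum_{b=1}^{B} \Big( T_\theta[\rho^{(b)}] - T^{(b)} \Big)^2 ,$$

where $B$ is the batch size and $T^{(b)} = E_0^{(b)} - \sum_i \rho_i^{(b)} v_i^{(b)}$ is the target kinetic energy."""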
opt_init, opt_update, get_params = optimizers.adam(step_size=1e-2)
opt_state = opt_init(params)
_ ,rhos, Es = get_batch(key, batchsize)
ts = jax.vmap(network.apply, (None, None, 0))(params, None, rhos)
print(ts.shape, jnp.sum(rhos* ts, axis=1).shape)
@jax.jit
def step(i, key, opt_state):
    params = get_params(opt_state)
    _, rhos, Es = get_batch(key, batchsize)
    value, grad = jax.value_and_grad(energy_loss)(params, rhos, Es)
    opt_state = opt_update(i, grad, opt_state)
    return value, opt_state
with trange(1000) as t:
    for epoch in t:
        key, subkey = jax.random.split(key)
        value, opt_state = step(epoch, subkey, opt_state)
        t.set_postfix(loss=value)
#from jax.flatten_util import ravel_pytree
#p, unravel = ravel_pytree(params)
#f = lambda p: energy_loss(unravel(p), rhos, Es)
#res = jax.scipy.optimize.minimize(f, p, method='BFGS')
#print("success:", res.success, "\nniterations:", res.nit, "\nfinal loss:", res.fun)
"""Now, the real thing: we define the loss function as the integrated density difference"""
def make_loss(network):
    def total_energy(x, v, params):
        # Softmax parametrization: x is unconstrained, rho is positive and sums to one.
        x = x - jax.scipy.special.logsumexp(x)
        rho = jnp.exp(x)
        t = network.apply(params, None, rho)
        return jnp.sum(rho * (t + v))

    def euler_equation(x, v, params):
        # Euler equation residual dT/drho + v, used as the optimality condition
        # for implicit differentiation of the solver below.
        x = x - jax.scipy.special.logsumexp(x)
        rho = jnp.exp(x)
        t = lambda rho: jnp.sum(network.apply(params, None, rho) * rho)
        return jax.grad(t)(rho) + v
    #euler_equation = jax.grad(total_energy)

    @jaxopt.implicit_diff.custom_root(euler_equation)
    def euler_solver(x0, v, params):
        # Minimize the total energy with BFGS; custom_root makes gradients w.r.t.
        # params flow through the solution implicitly instead of unrolling BFGS.
        f = lambda x: total_energy(x, v, params)
        res = jax.scipy.optimize.minimize(f, x0, method='BFGS')
        return res.x

    def density_loss(params, vs, rhos, Es):
        del Es
        rhos_star = predict(params, vs)
        return jnp.mean(((rhos_star - rhos) * N)**2)
        #return jnp.sum(rhos * (jnp.log(rhos) - jnp.log(rhos_star)))/batchsize

    def predict(params, vs):
        # Solve the Euler equation for each potential in the batch and return
        # the normalized self-consistent densities.
        xs_star = jax.vmap(euler_solver, (0, 0, None))(jnp.zeros_like(vs), vs, params)
        xs_star = xs_star - jax.scipy.special.logsumexp(xs_star, axis=-1, keepdims=True)
        rhos_star = jnp.exp(xs_star)
        return rhos_star

    def energy_loss(params, vs, rhos, Es):
        del vs
        ts = jax.vmap(network.apply, (None, None, 0))(params, None, rhos)
        return jnp.mean((jnp.sum(rhos * ts, axis=1) - Es)**2)

    def total_loss(params, vs, rhos, Es):
        return density_loss(params, vs, rhos, Es) + energy_loss(params, vs, rhos, Es)

    return total_loss, predict
loss_fn, predict_fn = make_loss(network)
opt_init, opt_update, get_params = optimizers.adam(step_size=1e-2)
opt_state = opt_init(params)
@jax.jit
def step(i, key, opt_state):
    params = get_params(opt_state)
    vs, rhos, Es = get_batch(key, batchsize)
    value, grad = jax.value_and_grad(loss_fn)(params, vs, rhos, Es)
    opt_state = opt_update(i, grad, opt_state)
    return value, opt_state
with trange(1000) as t:
    for epoch in t:
        key, subkey = jax.random.split(key)
        value, opt_state = step(epoch, subkey, opt_state)
        t.set_postfix(loss=value)
"""We generate some more data, and see how good is the prediction"""
key, subkey = jax.random.split(key)
vs, rhos, Es = get_batch(subkey, batchsize)
predict_rhos = predict_fn(get_params(opt_state), vs)
# Exact vs. predicted self-consistent density for the first test sample.
plt.plot(xmesh, rhos[0, :])
plt.plot(xmesh, predict_rhos[0, :])
plt.show()

# Learned functional derivative dT/drho evaluated on an exact density.
t = lambda rho: jnp.sum(network.apply(get_params(opt_state), None, rho) * rho)
plt.plot(xmesh, jax.grad(t)(rhos[0, :]))
plt.show()
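"""One more check that could be added here (a sketch, not in the original notebook): compare the model kinetic and total energies on the predicted self-consistent densities against the exact reference values from the solver."""
# Added sketch: energies from the predicted densities vs. exact reference values.
_params = get_params(opt_state)
_ts = jax.vmap(network.apply, (None, None, 0))(_params, None, predict_rhos)
_Ts = jnp.sum(predict_rhos * _ts, axis=1)        # model kinetic energy T_theta[rho*]
print("model T:", _Ts, " exact T:", Es)
print("model E:", _Ts + jnp.sum(predict_rhos * vs, axis=1),
      " exact E:", Es + jnp.sum(rhos * vs, axis=1))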
"""#sandbox
## TODO
- [x] euler solver with BFGS optimization
- [x] loss with nstar = (nstar - rhos)
- [x] implicit gradient
- [x] switch to haiku and use float64
- [ ] basis for density ?
"""