Advanced Computing Platform for Theoretical Physics

Commit 8d99a80b authored by Lei Wang

gives max_iter and epsilon to while_loop

parent 77e44c43
# Learning kinetic energy functional via differentiable optimization
Here, the key is to include the integrated density difference in the loss function. Training this then requires one to
-differentiate through the Euler equation solver of the orbital-free density functionals, which can be conveniently done with `jaxopt.implicit_diff.custom_root` of https://arxiv.org/abs/2105.15183.
+differentiate through the Euler equation solver of the orbital-free density functionals, which can be conveniently done with `jaxopt.implicit_diff.custom_fixed_point` of https://arxiv.org/abs/2105.15183.
Training will be more expensive (roughly 10x slower than using only the total energy and its functional derivative in the loss), but the hope is that the learned kinetic energy functional is more generalizable.
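As a self-contained illustration of this implicit-differentiation pattern, here is a minimal sketch with a toy fixed-point map (the names `fixed_point_fun`, `solver`, and `theta` are ours, not the repo's; the real code wraps `ofdft_solver` the same way):

```python
import jax
import jax.numpy as jnp
import jaxopt

def fixed_point_fun(x, theta):
    # Toy contraction standing in for the OF-DFT update step: its fixed
    # point solves x = theta / (1 + |x|), so x*(theta) depends smoothly
    # on theta (for theta = 2, x* = 1).
    return theta / (1.0 + jnp.abs(x))

@jaxopt.implicit_diff.custom_fixed_point(fixed_point_fun)
def solver(x0, theta):
    # Plain forward iteration; the decorator supplies d(x*)/d(theta) via
    # the implicit function theorem, so the loop is never unrolled.
    body = lambda i, x: fixed_point_fun(x, theta)
    return jax.lax.fori_loop(0, 100, body, x0)

x_star = solver(0.5, 2.0)                           # -> 1.0
dx_dtheta = jax.grad(solver, argnums=1)(0.5, 2.0)   # -> 1/3, via implicit diff
```

Because gradients come from the fixed-point condition rather than from unrolling the forward loop, training through the Euler solver stays affordable in memory and compute.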
@@ -12,10 +12,9 @@ pip install git+https://github.com/google/jaxopt
pip install git+https://github.com/deepmind/dm-haiku
```
## To run
```python
-python src/train.py --k 100 --b 32 --use_density_loss
-python src/inference.py --restore_path --k 100 --b 32 data/kejax_N1_G300_d2_c16_k100_b32_lr0.0001_usedensityloss
+python src/train.py --N 2 --k 100 --b 32 --use_density_loss
+python src/inference.py --N 2 --k 100 --b 32 --restore_path data/kejax_N2_G300_d2_c16_k100_b32_lr0.0001_usedensityloss
```
@@ -10,6 +10,8 @@ parser.add_argument("--batchsize", type=int, default=16, help="batchsize")
parser.add_argument("--epochs", type=int, default=100000, help="epochs")
parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
parser.add_argument("--use_density_loss", action='store_true', help="")
parser.add_argument("--max_iter", type=int, default=100, help="maximum number of iterations in euler solver")
parser.add_argument("--epsilon", type=float, default=1e-8, help="converge criterion of density norm in euler solver")
#model parameters
parser.add_argument("--depth", type=int, default=2, help="depth")
......
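With the two new flags, the solver budget is configurable from the command line. A hypothetical invocation extending the README's example (the values 200 and 1e-10 are illustrative, not tuned recommendations):

```
python src/train.py --N 2 --k 100 --b 32 --use_density_loss --max_iter 200 --epsilon 1e-10
```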
import jax
from jax.config import config
config.update("jax_enable_x64", True)
#config.update('jax_disable_jit', True)
import jax.numpy as jnp
import io, time
import matplotlib.pyplot as plt
@@ -30,17 +32,20 @@ ts = jax.vmap(network.apply, (None, None, 0))(params, None, rhos)
print(ts.shape, jnp.sum(rhos* ts, axis=1).shape)
rho0 = jnp.mean(jnp.sqrt(rhos), axis=0)
-density_loss, energy_loss, potential_loss, predict_fn, kinetic_energy_fn = make_loss(network, K, rho0, args.N, dx)
+density_loss, energy_loss, potential_loss, predict_fn, kinetic_energy_fn = make_loss(network, K, rho0, args.N, dx, args.max_iter, args.epsilon)
print ('energy_loss:', energy_loss(params, vs, rhos, Es, mus))
print ('potential_loss:', potential_loss(params, vs, rhos, Es, mus))
print ('density_loss:', density_loss(params, vs, rhos, Es, mus))
-predict_rhos = predict_fn(params, vs)
+predicted_rhos, converge = predict_fn(params, vs)
+if not converge:
+    print ('!!! Euler solver did not converge in %g steps to reach %g accuracy'%(args.max_iter, args.epsilon))
for i in range(args.batchsize):
    plt.subplot(121)
-    plt.plot(xmesh, predict_rhos[i, :], label=r'$\rho$')
+    plt.plot(xmesh, predicted_rhos[i, :], label=r'$\rho$')
    plt.plot(xmesh, rhos[i, :], label='true')
    plt.legend()
......
@@ -4,7 +4,7 @@ import jax.scipy.optimize
import jaxopt
-def make_loss(network, K, x0, N, dx):
+def make_loss(network, K, x0, N, dx, max_iter, epsilon):
    def kinetic_energy(rho, params):
        phi = jnp.sqrt(rho)
@@ -19,20 +19,20 @@ def make_loss(network, K, x0, N, dx):
    @jaxopt.implicit_diff.custom_fixed_point(ofdft_solver)
    def euler_solver(x0, v, params):
        def _body_fn(carry):
-            _, x = carry
-            return x, ofdft_solver(x,v,params)
+            i, _, x = carry
+            return i+1, x, ofdft_solver(x,v,params)
        def _cond_fn(carry):
-            x_prev, x = carry
-            return jnp.linalg.norm(x_prev-x) > 1e-8
+            i, x_prev, x = carry
+            return (i < max_iter) & (jnp.linalg.norm(x_prev-x) > epsilon)
-        init_carry = (x0, ofdft_solver(x0,v,params))
-        _, x_star = jax.lax.while_loop(_cond_fn, _body_fn, init_carry)
-        return x_star
+        init_carry = (0, x0, ofdft_solver(x0,v,params))
+        n_iter, _, x_star = jax.lax.while_loop(_cond_fn, _body_fn, init_carry)
+        return x_star, (n_iter < max_iter)
    def density_loss(params, vs, rhos, Es, mus):
        del Es, mus
-        rhos_star = predict(params, vs)
+        rhos_star, _ = predict(params, vs)
        return jnp.mean(jnp.sum(((rhos_star - rhos)/dx)**2*dx, axis=1))
    def energy_loss(params, vs, rhos, Es, mus):
@@ -47,7 +47,8 @@ def make_loss(network, K, x0, N, dx):
    @jax.jit
    def predict(params, vs):
-        return jax.vmap(euler_solver, (None, 0, None))(x0, vs, params)
+        rhos_star, converge = jax.vmap(euler_solver, (None, 0, None))(x0, vs, params)
+        return rhos_star, jnp.alltrue(converge)
    return density_loss, energy_loss, potential_loss, predict, kinetic_energy
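Taken together, the new solver logic amounts to the following standalone pattern (a sketch with illustrative names; `step` stands in for `ofdft_solver`, and the Babylonian square-root update is just a placeholder fixed-point map): a counter rides along in the `jax.lax.while_loop` carry, so iteration stops after `max_iter` steps and the caller learns whether the tolerance `epsilon` was actually reached.

```python
import jax
import jax.numpy as jnp

def bounded_fixed_point(step, x0, max_iter=100, epsilon=1e-8):
    """Iterate x <- step(x) until the update norm drops below epsilon,
    but give up after max_iter steps; also report whether we converged."""
    def cond_fn(carry):
        i, x_prev, x = carry
        return (i < max_iter) & (jnp.linalg.norm(x_prev - x) > epsilon)

    def body_fn(carry):
        i, _, x = carry
        return i + 1, x, step(x)

    init_carry = (0, x0, step(x0))
    n_iter, _, x_star = jax.lax.while_loop(cond_fn, body_fn, init_carry)
    # Exiting with n_iter == max_iter means the tolerance was never reached.
    return x_star, n_iter < max_iter

# Babylonian iteration for sqrt(2), a stand-in for the OF-DFT update.
x_star, converged = bounded_fixed_point(lambda x: 0.5 * (x + 2.0 / x),
                                        jnp.array([1.0]))
```

One subtlety worth checking: since `euler_solver` now returns a `(solution, flag)` pair, `jaxopt.implicit_diff.custom_fixed_point` may need its `has_aux=True` option, which exists for exactly this `(solution, aux)` return shape, so that implicit differentiation acts on `x_star` alone.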
import jax
from jax.config import config
config.update("jax_enable_x64", True)
#config.update('jax_disable_jit', True)
import jax.numpy as jnp
from jax.experimental import optimizers
@@ -40,7 +39,7 @@ ts = jax.vmap(network.apply, (None, None, 0))(params, None, rhos)
print(ts.shape, jnp.sum(rhos* ts, axis=1).shape)
rho0 = jnp.mean(jnp.sqrt(rhos), axis=0)
-density_loss, energy_loss, potential_loss, predict_fn, kinetic_energy_fn = make_loss(network, K, rho0, args.N, dx)
+density_loss, energy_loss, potential_loss, predict_fn, kinetic_energy_fn = make_loss(network, K, rho0, args.N, dx, args.max_iter, args.epsilon)
def loss_fn(params, vs, rhos, Es, mus):
    total_loss = energy_loss(params, vs, rhos, Es, mus) + \
......