Advanced Computing Platform for Theoretical Physics

Commit ac966158 authored by Lei Wang's avatar Lei Wang
Browse files

submit: optionally include the density loss term

parent c7afe796
......@@ -2,11 +2,11 @@ import subprocess
import numpy as np
import time

# Tag for this batch of runs; results are written under resfolder below.
# (Removed the stale 'first_try' / [16, 32] assignments that were dead code,
# immediately overwritten by the values kept here.)
nickname = 'fixunit'
###############################
# Hyperparameter grid submitted to the cluster (one job per combination).
batchsizes = [32]
kernelsizes = [100]
use_density_loss = True  # forwarded to train.py as --use_density_loss
###############################
prog = '../src/train.py'
resfolder = '/data1/wanglei/kejax/' + nickname + '/'
......
#!/bin/bash -l
#
# PBS job script: runs train.py with kernel size 32 on one GPU node.
#PBS -l nodes=1:ppn=1:gpus=1
#PBS -l walltime=24:00:00
#PBS -N k32
#PBS -o ../jobs/k32.log
#PBS -j oe
#PBS -V
# NOTE(review): -N expects a plain job name (no path separators); the log
# path belongs in -o, which already points at ../jobs/k32.log.

# Number of CPU slots the scheduler assigned to this job.
ncpu=`cat $PBS_NODEFILE | wc -l`
# Fixed: the original echoed the literal words "uniq -c $PBS_NODEFILE"
# because the command substitution backticks were missing.
echo "Running on" `uniq -c $PBS_NODEFILE`
cd $PBS_O_WORKDIR
echo "Running from $PBS_O_WORKDIR"
echo "CUDA devices $CUDA_VISIBLE_DEVICES"
echo Job started at `date`
python ../src/train.py --k 32
echo Job finished at `date`
......@@ -47,13 +47,16 @@ if __name__=='__main__':
args = {'k':k,
'b':b,
'folder':resfolder
'folder':resfolder,
'use_density_loss':use_density_loss
}
jobname = jobdir
for key, val in args.items():
if not(key in ['folder', 'restore_path']):
if not(key in ['folder', 'restore_path', 'use_density_loss']):
jobname += str(key) + str(val) + '_'
jobname = jobname[:-1]
if key == 'use_density_loss' and val:
jobname += 'usedensityloss' + '_'
jobname = jobname[:-1] # remove last '_'
jobid = submitJob(prog,args,jobname,run=input.run, wait=input.waitfor)
......@@ -9,6 +9,7 @@ parser.add_argument("--G", type=int, default=300, help="G")
def _str2bool(s):
    """Parse a command-line boolean so that '--use_density_loss False' is False.

    argparse's type=bool calls bool() on the raw string, so ANY non-empty
    argument (including the string 'False') is treated as True. This
    converter accepts the usual spellings of true/false instead.
    """
    return str(s).lower() in ('1', 'true', 'yes', 'y', 't')

parser.add_argument("--batchsize", type=int, default=16, help="batchsize")
parser.add_argument("--epochs", type=int, default=100000, help="epochs")
parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
parser.add_argument("--use_density_loss", type=_str2bool, default=False,
                    help="add the density reconstruction term to the loss")
#model parameters
parser.add_argument("--depth", type=int, default=2, help="depth")
......@@ -25,4 +26,7 @@ save_path += 'kejax' \
+ '_c' + str(args.channels) \
+ '_k' + str(args.kernelsize) \
+ '_b' + str(args.batchsize) \
+ '_lr' + str(args.lr)
+ '_lr' + str(args.lr)
if args.use_density_loss:
save_path += '_usedensityloss'
......@@ -30,7 +30,7 @@ ts = jax.vmap(network.apply, (None, None, 0))(params, None, rhos)
print(ts.shape, jnp.sum(rhos* ts, axis=1).shape)
x0 = jnp.log(jnp.mean(rhos/args.N, axis=0))
loss_fn, predict_fn = make_loss(network, x0, args.N, dx)
_, _, _, predict_fn = make_loss(network, x0, args.N, dx)
predict_rhos = predict_fn(params, vs)
......
......@@ -23,14 +23,7 @@ def make_loss(network, x0, N, dx):
def density_loss(params, vs, rhos, Es):
    """Mean integrated squared error between predicted and reference densities.

    The energies `Es` are accepted for signature uniformity with the other
    loss terms but are not used here.
    """
    del Es  # unused in this term
    predicted = predict(params, vs)
    residual = (predicted - rhos) / dx
    return jnp.mean(jnp.sum(residual ** 2 * dx, axis=1))
def energy_loss(params, vs, rhos, Es):
del vs
......@@ -42,11 +35,12 @@ def make_loss(network, x0, N, dx):
t = lambda rho: jnp.sum(network.apply(params, None, rho)*rho)
return jnp.mean(jnp.sum((jax.vmap(jax.grad(t))(rhos) + vs)**2 *dx, axis=1))
def total_loss(params, vs, rhos, Es):
return energy_loss(params, vs, rhos, Es) + \
density_loss(params, vs, rhos, Es) + \
potential_loss(params, vs, rhos, Es)
return total_loss, predict
@jax.jit
def predict(params, vs):
    """Predict densities for a batch of potentials `vs`.

    Integrates the flow from the shared initial log-density x0 for every
    potential, renormalizes in log space, and rescales by particle number N.
    """
    log_rho = jax.vmap(euler_solver, (None, 0, None))(x0, vs, params)
    # Normalize so each row exponentiates to a probability distribution,
    # then scale to N particles.
    log_rho -= jax.scipy.special.logsumexp(log_rho, axis=-1, keepdims=True)
    return jnp.exp(log_rho) * N
return density_loss, energy_loss, potential_loss, predict
......@@ -38,7 +38,14 @@ ts = jax.vmap(network.apply, (None, None, 0))(params, None, rhos)
print(ts.shape, jnp.sum(rhos* ts, axis=1).shape)
x0 = jnp.log(jnp.mean(rhos*dx/args.N, axis=0))
loss_fn, predict_fn = make_loss(network, x0, args.N, dx)
density_loss, energy_loss, potential_loss, predict_fn = make_loss(network, x0, args.N, dx)
def loss_fn(params, vs, rhos, Es):
    """Combine the individual loss terms into the training objective.

    Always includes the energy and potential terms; the density
    reconstruction term is opt-in via the --use_density_loss flag.
    """
    combined = energy_loss(params, vs, rhos, Es) + potential_loss(params, vs, rhos, Es)
    if args.use_density_loss:
        combined = combined + density_loss(params, vs, rhos, Es)
    return combined
# Adam optimizer via jax's functional API: opt_state is threaded through
# opt_update/get_params during training.
opt_init, opt_update, get_params = optimizers.adam(step_size=args.lr)
opt_state = opt_init(params)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment