Advanced Computing Platform for Theoretical Physics

Commit ea986ce3 authored by Lei Wang

silu activation, no bias in conv, coef in args

parent db5dabf8
````diff
@@ -15,6 +15,6 @@ pip install git+https://github.com/deepmind/dm-haiku
 ## To run
 ```python
-python src/train.py --N 2 --use_density_loss
-python src/inference.py --N 2 --restore_path data/kejax_N2_G300_d2_c16_k100_b32_lr0.0001_dnls
+python src/train.py --N 2 --dnls 1.0
+python src/inference.py --N 2 --restore_path data/kejax_N2_G300_d2_c16_k100_b32_lr0.0001_en1.0_pt0.0_dn1.0
 ```
````
```diff
@@ -2,16 +2,15 @@ import subprocess
 import numpy as np
 import time
-nickname = 'fixed_point'
+nickname = 'silu-nobias-silu'
 ###############################
-use_potential_loss = True
-use_density_loss = False
+ptls, dnls = 0.0, 1.0
 N = 2
-lr = 1e-4
+lr = 1e-3
 G = 300
 depths = [2]
 channels = [16]
-batchsizes = [128]
+batchsizes = [32]
 kernelsizes = [100]
 ###############################
 prog = '../src/train.py'
```
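The lists above presumably drive a hyperparameter sweep; a hedged sketch of the kind of loop that would consume them (the nested iteration is an assumption, not shown in this diff):

```python
from itertools import product

# Hypothetical sweep driver over the lists defined above:
# one job per (depth, channels, batchsize, kernelsize) combination.
for d, c, b, k in product(depths, channels, batchsizes, kernelsizes):
    args = {'N': N, 'G': G, 'd': d, 'c': c, 'k': k, 'b': b, 'lr': lr,
            'ptls': ptls, 'dnls': dnls}
    # ...build the job script from args and submit it...
```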
```diff
@@ -46,10 +45,7 @@ echo "CUDA devices $CUDA_VISIBLE_DEVICES"
 echo Job started at `date`\n'''
 job +='python '+ str(bin) + ' '
 for key, val in args.items():
-    if key in ['use_density_loss', 'use_potential_loss']:
-        if val: job += '--'+str(key) + ' '
-    else:
-        job += '--'+str(key) + ' '+ str(val) + ' '
+    job += '--'+str(key) + ' '+ str(val) + ' '
 job += '''
 echo Job finished at `date`\n'''
```
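With the boolean special-casing removed, the loop emits every argument uniformly as `--key value`. A minimal standalone sketch of the command it builds (hypothetical values, matching the settings at the top of the script):

```python
# Sketch: the simplified loop emits "--key value" for every entry.
args = {'N': 2, 'G': 300, 'lr': 1e-3, 'ptls': 0.0, 'dnls': 1.0}
job = 'python ../src/train.py '
for key, val in args.items():
    job += '--' + str(key) + ' ' + str(val) + ' '
print(job)
# python ../src/train.py --N 2 --G 300 --lr 0.001 --ptls 0.0 --dnls 1.0
```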
```diff
@@ -56,17 +56,13 @@ if __name__=='__main__':
         'b':b,
         'lr':lr,
         'folder':resfolder,
-        'use_potential_loss':use_potential_loss,
-        'use_density_loss':use_density_loss
+        'ptls':ptls,
+        'dnls':dnls
         }
     jobname = jobdir
     for key, val in args.items():
-        if not(key in ['folder', 'restore_path', 'use_density_loss','use_potential_loss']):
+        if not(key in ['folder', 'restore_path']):
             jobname += str(key) + str(val) + '_'
-        if key == 'use_density_loss' and val:
-            jobname += 'dnls' + '_'
-        if key == 'use_potential_loss' and val:
-            jobname += 'ptls' + '_'
     jobname = jobname[:-1] # remove last '_'
```
```diff
@@ -4,13 +4,15 @@ parser.add_argument("--folder", default='../data/',help="where to store results"
 parser.add_argument("--restore_path", default=None, help="checkpoint file path")
 #physical parameters
-parser.add_argument("--N", type=int, default=1, help="particle number")
+parser.add_argument("--N", type=int, default=2, help="particle number")
 parser.add_argument("--G", type=int, default=300, help="number of grid points")
 parser.add_argument("--batchsize", type=int, default=16, help="batchsize")
 parser.add_argument("--epochs", type=int, default=100000, help="epochs")
 parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
-parser.add_argument("--use_density_loss", action='store_true', help="")
-parser.add_argument("--use_potential_loss", action='store_true', help="")
+parser.add_argument("--enls", type=float, default=1.0, help="energy loss")
+parser.add_argument("--ptls", type=float, default=0.0, help="potential loss")
+parser.add_argument("--dnls", type=float, default=0.0, help="density loss")
 parser.add_argument("--max_iter", type=int, default=100, help="maximum number of iterations in euler solver")
 parser.add_argument("--epsilon", type=float, default=1e-8, help="converge criterion of density norm in euler solver")
```
```diff
@@ -29,8 +31,7 @@ save_path += 'kejax' \
     + '_c' + str(args.channels) \
     + '_k' + str(args.kernelsize) \
     + '_b' + str(args.batchsize) \
-    + '_lr' + str(args.lr)
-if args.use_potential_loss:
-    save_path += '_ptls'
-if args.use_density_loss:
-    save_path += '_dnls'
+    + '_lr' + str(args.lr) \
+    + '_en' + str(args.enls) \
+    + '_pt' + str(args.ptls) \
+    + '_dn' + str(args.dnls)
```
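For reference, a sketch reconstructing the full checkpoint folder suffix this produces; the `'_N'`, `'_G'` and `'_d'` pieces sit above this hunk and are inferred from the checkpoint name in the README:

```python
# Sketch: rebuild the checkpoint folder suffix from the argparse values.
class Args:  # stand-in for the argparse namespace
    N, G, depth, channels, kernelsize = 2, 300, 2, 16, 100
    batchsize, lr, enls, ptls, dnls = 32, 1e-4, 1.0, 0.0, 1.0

args = Args()
save_path = ('kejax'
             + '_N' + str(args.N)
             + '_G' + str(args.G)
             + '_d' + str(args.depth)
             + '_c' + str(args.channels)
             + '_k' + str(args.kernelsize)
             + '_b' + str(args.batchsize)
             + '_lr' + str(args.lr)
             + '_en' + str(args.enls)
             + '_pt' + str(args.ptls)
             + '_dn' + str(args.dnls))
print(save_path)  # kejax_N2_G300_d2_c16_k100_b32_lr0.0001_en1.0_pt0.0_dn1.0
```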
```diff
@@ -11,11 +11,12 @@ def make_network(depth, channels, kernelsize):
         x = jnp.expand_dims(x, axis=1)
         for _ in range(depth):
             resblock = hk.Sequential([
-                hk.Conv1D(output_channels=channels, kernel_shape=kernelsize, padding='SAME'),
-                jax.nn.softplus,
-                hk.Conv1D(output_channels=channels, kernel_shape=kernelsize, padding='SAME'),
+                hk.Conv1D(output_channels=channels, kernel_shape=kernelsize, padding='SAME',with_bias=False),
+                jax.nn.silu,
+                hk.Conv1D(output_channels=channels, kernel_shape=kernelsize, padding='SAME',with_bias=False),
                 ])
-            x = jax.nn.softplus(x + resblock(x))
-        return hk.Conv1D(output_channels=1, kernel_shape=kernelsize, padding='SAME')(x).squeeze(axis=1)
+            x = jax.nn.silu(x + resblock(x))
+        x = hk.Conv1D(output_channels=1, kernel_shape=kernelsize, padding='SAME',with_bias=False)(x).squeeze(axis=1)
+        return jax.nn.silu(x)
     return hk.transform(forward_fn)
```
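Assembled from the hunk, the full forward function after this commit reads as below (a sketch; the imports and the enclosing `make_network` signature are taken from the diff context):

```python
import jax
import jax.numpy as jnp
import haiku as hk

def make_network(depth, channels, kernelsize):
    def forward_fn(x):
        # (G,) -> (G, 1): grid points as the spatial axis, one input channel
        x = jnp.expand_dims(x, axis=1)
        for _ in range(depth):
            # bias-free conv -> SiLU -> bias-free conv, as of this commit
            resblock = hk.Sequential([
                hk.Conv1D(output_channels=channels, kernel_shape=kernelsize,
                          padding='SAME', with_bias=False),
                jax.nn.silu,
                hk.Conv1D(output_channels=channels, kernel_shape=kernelsize,
                          padding='SAME', with_bias=False),
            ])
            x = jax.nn.silu(x + resblock(x))  # residual connection + SiLU
        # project back to a single channel, then the new final SiLU
        x = hk.Conv1D(output_channels=1, kernel_shape=kernelsize,
                      padding='SAME', with_bias=False)(x).squeeze(axis=1)
        return jax.nn.silu(x)
    return hk.transform(forward_fn)
```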
```diff
@@ -42,11 +42,11 @@ rho0 = jnp.mean(jnp.sqrt(rhos), axis=0)
 density_loss, energy_loss, potential_loss, predict_fn, kinetic_energy_fn = make_loss(network, K, rho0, args.N, dx, args.max_iter, args.epsilon)
 def loss_fn(params, vs, rhos, Es, mus):
-    total_loss = energy_loss(params, vs, rhos, Es, mus)
-    if args.use_potential_loss:
-        total_loss += potential_loss(params, vs, rhos, Es, mus)
-    if args.use_density_loss:
-        total_loss += density_loss(params, vs, rhos, Es, mus)
+    total_loss = args.enls * energy_loss(params, vs, rhos, Es, mus)
+    if args.ptls > 0:
+        total_loss += args.ptls * potential_loss(params, vs, rhos, Es, mus)
+    if args.dnls > 0:
+        total_loss += args.dnls * density_loss(params, vs, rhos, Es, mus)
     return total_loss
 opt_init, opt_update, get_params = optimizers.adam(step_size=args.lr)
```
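A minimal sketch of how the coefficient-weighted `loss_fn` plugs into the `optimizers.adam` triple above (the update step itself is not part of this diff and is an assumption; `params` and `loss_fn` are the objects defined in the surrounding code). The `args.ptls > 0` branches are on plain Python floats, so the function stays jit-compatible:

```python
import jax

opt_state = opt_init(params)  # params: initialized network parameters

@jax.jit
def step(i, opt_state, vs, rhos, Es, mus):
    params = get_params(opt_state)
    # gradient of the coefficient-weighted total loss
    loss, grads = jax.value_and_grad(loss_fn)(params, vs, rhos, Es, mus)
    return loss, opt_update(i, grads, opt_state)
```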