Advanced Computing Platform for Theoretical Physics

commit大文件会使得服务器变得不稳定,请大家尽量只commit代码,不要commit大的文件。

Commit f2593182 authored by Lei Wang's avatar Lei Wang
Browse files

make ising work

parent 306844bb
......@@ -4,9 +4,7 @@ import torch
torch.set_num_threads(4)
torch.manual_seed(42)
#from hotrg2 import hotrg as contraction
#from trg import levin_nave_trg as contraction
from ctmrg import ctmrg as contraction
from vmps import vmps as contraction
if __name__=='__main__':
import time
......@@ -19,6 +17,7 @@ if __name__=='__main__':
parser.add_argument("-Nepochs", type=int, default=100, help="Nepochs")
parser.add_argument("-float32", action='store_true', help="use float32")
parser.add_argument("-lanczos_steps", type=int, default=0, help="lanczos steps")
parser.add_argument("-cuda", type=int, default=-1, help="use GPU")
args = parser.parse_args()
device = torch.device("cpu" if args.cuda<0 else "cuda:"+str(args.cuda))
......@@ -29,7 +28,7 @@ if __name__=='__main__':
D = args.D
Dcut = args.Dcut
Niter = args.Niter
beta = torch.tensor([args.beta], dtype=dtype, device=device).requires_grad_()
beta = torch.tensor([args.beta], dtype=dtype, device=device)
B = 0.01* torch.randn(d, D, D, D, D, dtype=dtype, device=device)
#symmetrize
......@@ -39,18 +38,21 @@ if __name__=='__main__':
B = (B + B.permute(0, 2, 1, 4, 3))/2.
A = torch.nn.Parameter(B.view(d, D**4))
#boundary MPS
A1 = torch.nn.Parameter(0.01*torch.randn(Dcut, D**2*d, Dcut, dtype=dtype, device=device))
A2 = torch.nn.Parameter(0.01*torch.randn(Dcut, D**2, Dcut, dtype=dtype, device=device))
#3D Ising
c = torch.sqrt(torch.cosh(beta)/2.)
s = torch.sqrt(torch.sinh(beta)/2.)
M = torch.stack([torch.cat([c+s, c-s]), torch.cat([c-s, c+s])])
T = torch.einsum('ai,aj,ak,al,am,an->ijklmn', (M, M, M, M, M, M))
T = T.view(d, d**4, d)
print (T)
#print (T)
optimizer = torch.optim.LBFGS([A], max_iter=20)
def closure():
if beta.grad is not None: beta.grad.zero_()
optimizer.zero_grad()
T1 = torch.einsum('xa,xby,yc' , (A,T,A)).view(D,D,D,D, d,d,d,d, D,D,D,D).permute(0,4,8, 1,5,9, 2,6,10, 3,7,11).contiguous().view(D**2*d, D**2*d, D**2*d, D**2*d)
......@@ -58,11 +60,11 @@ if __name__=='__main__':
T2 = (A.t()@A).view(D, D, D, D, D, D, D, D).permute(0,4, 1,5, 2,6, 3,7).contiguous().view(D**2, D**2, D**2, D**2)
t0=time.time()
lnT1, error1 = contraction(T1, D**2*d, Dcut, Niter)
lnT2, error2 = contraction(T2, D**2, Dcut, Niter)
lnT = contraction(T1, D**2*d, Dcut, Niter, A1, lanczos_steps=args.lanczos_steps)
lnZ = contraction(T2, D**2, Dcut, Niter, A2, lanczos_steps=args.lanczos_steps)
loss = (-lnT1 + lnT2) # loss = -lnZ of Ising
print (' contraction done {:.3f}s'.format(time.time()-t0))
print (' loss, error', loss.item(), error1.item(), error2.item())
print (' total error', loss.item())
t0=time.time()
loss.backward(retain_graph=True)
......@@ -71,5 +73,4 @@ if __name__=='__main__':
for epoch in range(args.Nepochs):
loss = optimizer.step(closure)
En = beta.grad.item() # En = -d lnZ / d beta
print ('epoch, free energy, energy', epoch, loss.item(), En)
print ('epoch, free energy, energy', epoch, loss.item())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment