Advanced Computing Platform for Theoretical Physics

commit大文件会使得服务器变得不稳定,请大家尽量只commit代码,不要commit大的文件。

Commit 4162baf1 authored by Lei Wang's avatar Lei Wang
Browse files

looks like ctmrg with iteration is the fastest

parent 4ff4842c
import torch
from itertools import permutations
def ctmrg(T, d, Dcut, no_iter):
def ctmrg(T, d, Dcut, max_iter):
#symmetrize
T = (T + T.permute(3, 1, 2, 0))/2.
......@@ -11,15 +11,19 @@ def ctmrg(T, d, Dcut, no_iter):
lnZ = 0.0
truncation_error = 0.0
C = T.sum((0,1))
E = T.sum(1)
C = torch.randn(d, d, dtype=T.dtype, device=T.device) #T.sum((0,1))
E = torch.randn(d, d, d, dtype=T.dtype, device=T.device)#T.sum(1)
D = d
for n in range(no_iter):
sold = torch.zeros(d, dtype=T.dtype, device=T.device)
diff = 1E1
for n in range(max_iter):
A = torch.einsum('ab,eca,bdg,cdfh->efgh', (C, E, E, T)).contiguous().view(D*d, D*d)
A = (A+A.t())/2.
D_new = min(D*d, Dcut)
U, S, V = torch.svd(A)
s = S/S.max()
truncation_error += S[D_new:].sum()/S.sum()
P = U[:, :D_new] # projection operator
......@@ -41,11 +45,14 @@ def ctmrg(T, d, Dcut, no_iter):
D = D_new
maxval = C.max()
C = C/maxval
C = C/C.norm()
E = E/E.norm()
maxval = E.max()
E = E/maxval
if (s.numel() == sold.numel()):
diff = (s-sold).norm()
if (diff < 1E-8):
break
sold = s
Z1 = torch.einsum('ab,bcd,fd,gha,hcij,fjk,lg,mil,mk', (C,E,C,E,T,E,C,E,C))
#CEC = torch.einsum('da,ebd,ce->abc', (C,E,C)).view(1, D**2*d)
......@@ -55,21 +62,21 @@ def ctmrg(T, d, Dcut, no_iter):
Z3 = torch.einsum('ab,bc,cd,da', (C,C,C,C))
Z2 = torch.einsum('ab,bcd,de,fa,gcf,ge',(C,E,C,C,E,C))
print (' Z1, Z2, Z3:', Z1.item(), Z2.item(), Z3.item())
#print (' Z1, Z2, Z3:', Z1.item(), Z2.item(), Z3.item())
lnZ += torch.log(Z1.abs()) + torch.log(Z3.abs()) - 2.*torch.log(Z2.abs())
return lnZ, truncation_error
if __name__=='__main__':
torch.set_num_threads(4)
K = torch.tensor([0.44])
torch.set_num_threads(1)
K = torch.tensor([0.44], dtype=torch.float64)
Dcut = 180
n = 50
max_iter = 1000
c = torch.sqrt(torch.cosh(K)/2.)
s = torch.sqrt(torch.sinh(K)/2.)
M = torch.stack([torch.cat([c+s, c-s]), torch.cat([c-s, c+s])])
T = torch.einsum('ai,aj,ak,al->ijkl', (M, M, M, M))
lnZ, error = ctmrg(T, 2, Dcut, n)
lnZ, error = ctmrg(T, 2, Dcut, max_iter)
print (lnZ.item(), error)
......@@ -5,10 +5,12 @@ In a nutshell, it computes the maximum eigenvalue of tranfer matrix via variatio
[2] Levin and Nave, PRL 99, 120601 (2007)
'''
import torch
torch.set_num_threads(4)
torch.set_num_threads(1)
torch.manual_seed(42)
from vmps import vmps as contraction
from trg import levin_nave_trg as contraction
from ctmrg import ctmrg as contraction
#from vmps import vmps as contraction
if __name__=='__main__':
import time
......@@ -65,12 +67,14 @@ if __name__=='__main__':
#double layer
T2 = (A.t()@A).view(D, D, D, D, D, D, D, D).permute(0,4, 1,5, 2,6, 3,7).contiguous().view(D**2, D**2, D**2, D**2)
t0=time.time()
lnT = contraction(T1, D**2*d, Dcut, Niter, A1, lanczos_steps=args.lanczos_steps)
lnZ = contraction(T2, D**2, Dcut, Niter, A2, lanczos_steps=args.lanczos_steps)
#lnT = contraction(T1, D**2*d, Dcut, Niter, A1, lanczos_steps=args.lanczos_steps)
#lnZ = contraction(T2, D**2, Dcut, Niter, A2, lanczos_steps=args.lanczos_steps)
lnT, error1 = contraction(T1, D**2*d, Dcut, Niter)
lnZ, error2 = contraction(T2, D**2, Dcut, Niter)
loss = (-lnT + lnZ)
print (' contraction done {:.3f}s'.format(time.time()-t0))
print (' total loss', loss.item())
#print (' loss, error', loss.item(), error1.item(), error2.item())
print (' loss, error', loss.item(), error1.item(), error2.item())
t0=time.time()
loss.backward()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment