Advanced Computing Platform for Theoretical Physics

Commit 43862c5f authored by Lei Wang's avatar Lei Wang
Browse files

added ctmrg, compute lnZ using Z1*Z3/Z2^2

parent e1ba1120
import torch
def ctmrg(T, d, Dcut, no_iter):
lnZ = 0.0
truncation_error = 0.0
C = T[0, 0, :, :]
E = T[:, 0, :, :]
D = d
for n in range(no_iter):
maxval = C.max()
C = C/maxval
maxval = E.max()
E = E/maxval
A = torch.einsum('ab,eca,bdg,cdfh->efgh', (C, E, E, T)).contiguous().view(D*d, D*d)
D_new = min(D*d, Dcut)
U, S, V = torch.svd(A)
truncation_error += S[D_new:].sum()/S.sum()
P = U[:, :D_new] # projection operator
C = (P.t() @ A @ P) #(D, D)
#C = (C+C.t())/2.
ET = torch.einsum('ldr,adbc->labrc', (E, T)).contiguous().view(D*d, d, D*d)
E = torch.einsum('li,ldr,rj->idj', (P, ET, P)) #(D, d, D)
D = D_new
Z1 = torch.einsum('ab,bcd,fd,gha,hcij,fjk,lg,mil,mk', (C,E,C,E,T,E,C,E,C))
Z3 = torch.einsum('ab,bc,cd,da', (C,C,C,C))
Z2 = torch.einsum('ab,bcd,de,fa,gcf,ge',(C,E,C,C,E,C))
lnZ += torch.log(Z1) + torch.log(Z3) - 2.*torch.log(Z2)
return lnZ, truncation_error
if __name__=='__main__':
K = torch.tensor([0.44])
Dcut = 20
n = 20
#Boltzmann factor on a bond M=LR^T
M = torch.stack([torch.cat([torch.exp(K), torch.exp(-K)]),
torch.cat([torch.exp(-K), torch.exp(K)])
])
U, S, V = torch.svd(M)
L = U*torch.sqrt(S)
R = V*torch.sqrt(S)
# L
# |
# T = R^{T}-o-L
# |
# R^{T}
T = torch.einsum('ai,aj,ak,al->ijkl', (L, L, R, R))
lnZ, error = ctmrg(T, 2, Dcut, n)
print (lnZ.item(), error)
......@@ -8,8 +8,9 @@ import torch
torch.set_num_threads(4)
torch.manual_seed(42)
from hotrg2 import hotrg as contraction
#from hotrg2 import hotrg as contraction
#from trg import levin_nave_trg as contraction
from ctmrg import ctmrg as contraction
if __name__=='__main__':
import time
......@@ -42,7 +43,8 @@ if __name__=='__main__':
T[1, 0, 0, 0, 0, 0] = 1.0
T = T.view(d, d**4, d)
optimizer = torch.optim.LBFGS([A], max_iter=50)
#optimizer = torch.optim.LBFGS([A], max_iter=10)
optimizer = torch.optim.Adam([A])
def closure():
optimizer.zero_grad()
......@@ -54,7 +56,7 @@ if __name__=='__main__':
lnT, error1 = contraction(T1, D**2*d, Dcut, Niter)
lnZ, error2 = contraction(T2, D**2, Dcut, Niter)
loss = (-lnT + lnZ)/2**Niter
loss = (-lnT + lnZ)
print ('contraction done {:.3f}s'.format(time.time()-t0))
print ('residual entropy', -loss.item(), error1.item(), error2.item())
......@@ -62,5 +64,7 @@ if __name__=='__main__':
loss.backward()
print ('backward done {:.3f}s'.format(time.time()-t0))
return loss
optimizer.step(closure)
for epoch in range(100):
loss = optimizer.step(closure)
#print ('epoch, loss', epoch, loss)
......@@ -41,7 +41,7 @@ def hotrg(T, D, Dcut, no_iter):
for y in range(T.shape[1]):
trace += T[x, y, y, x]
lnZ += torch.log(trace)
return lnZ, truncation_error
return lnZ/2**no_iter, truncation_error
if __name__=='__main__':
K = torch.tensor([0.44])
......@@ -64,4 +64,4 @@ if __name__=='__main__':
T = torch.einsum('ai,aj,ak,al->ijkl', (L, L, R, R))
lnZ, _ = hotrg(T, 2, Dcut, n)
print (lnZ.item()/2**n)
print (lnZ.item())
......@@ -46,7 +46,7 @@ def hotrg(T, D, Dcut, no_iter):
for y in range(T.shape[1]):
trace += T[x, y, y, x]
lnZ += torch.log(trace)
return lnZ, truncation_error
return lnZ/2**no_iter, truncation_error
if __name__=='__main__':
import time
......@@ -70,4 +70,4 @@ if __name__=='__main__':
T = torch.einsum('ai,aj,ak,al->ijkl', (R, L, R, L))
lnZ, _ = hotrg(T, 2, Dcut, n)
print (lnZ.item()/2**n)
print (lnZ.item())
......@@ -10,6 +10,7 @@ def levin_nave_trg(T, D, Dcut, no_iter):
maxval = T.max()
T = T/maxval
lnZ += 2**(no_iter-n)*torch.log(maxval)
#print (2**(no_iter-n)*torch.log(maxval)/2**(no_iter))
Ma = T.permute(2, 1, 0, 3).contiguous().view(D**2, D**2)
Mb = T.permute(3, 2, 1, 0).contiguous().view(D**2, D**2)
......@@ -25,7 +26,9 @@ def levin_nave_trg(T, D, Dcut, no_iter):
S2 = (Ub[:, :D_new]* torch.sqrt(Sb[:D_new])).view(D, D, D_new)
S4 = (Vb[:, :D_new]* torch.sqrt(Sb[:D_new])).view(D, D, D_new)
T_new = torch.einsum('war,abu,bgl,gwd->ruld', (S1, S2, S3, S4))
#T_new = torch.einsum('war,abu,bgl,gwd->ruld', (S1, S2, S3, S4))
T_new = torch.tensordot( torch.tensordot(S1, S2, dims=([1], [0])),
torch.tensordot(S3, S4, dims=([1], [0])), dims=([2, 0], [0, 2]))
D = D_new
T = T_new
......@@ -34,9 +37,10 @@ def levin_nave_trg(T, D, Dcut, no_iter):
for x in range(D):
for y in range(D):
trace += T[x, y, y, x]
#print (torch.log(trace)/2**no_iter)
lnZ += torch.log(trace)
return lnZ, truncation_error
return lnZ/2**no_iter, truncation_error
if __name__=='__main__':
K = torch.tensor([0.44])
......@@ -58,5 +62,5 @@ if __name__=='__main__':
# R^{T}
T = torch.einsum('ai,aj,ak,al->ijkl', (L, L, R, R))
lnZ, _ = levin_nave_trg(T, 2, Dcut, n)
print (lnZ.item()/2**n)
lnZ, error = levin_nave_trg(T, 2, Dcut, n)
print (lnZ.item(), error)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment