Advanced Computing Platform for Theoretical Physics

Commit 5d8e8921 authored by Lei Wang's avatar Lei Wang
Browse files

simplify symmetrize;ctmrg for ising

parent 4aefaec8
import torch import torch
from itertools import permutations from itertools import permutations
def ctmrg(Tin, d, Dcut, no_iter): def ctmrg(T, d, Dcut, no_iter):
#symmetrize #symmetrize
T = torch.zeros_like(Tin) T = (T + T.permute(3, 1, 2, 0) + T.permute(0, 2, 1, 3) + T.permute(2, 3, 0, 1) + T.permute(1, 0, 3, 2))/5.
for i in permutations([0, 1, 2, 3], 4):
T = T + Tin.permute(i)
T = T/24.
lnZ = 0.0 lnZ = 0.0
truncation_error = 0.0 truncation_error = 0.0
...@@ -55,7 +52,7 @@ def ctmrg(Tin, d, Dcut, no_iter): ...@@ -55,7 +52,7 @@ def ctmrg(Tin, d, Dcut, no_iter):
Z3 = torch.einsum('ab,bc,cd,da', (C,C,C,C)) Z3 = torch.einsum('ab,bc,cd,da', (C,C,C,C))
Z2 = torch.einsum('ab,bcd,de,fa,gcf,ge',(C,E,C,C,E,C)) Z2 = torch.einsum('ab,bcd,de,fa,gcf,ge',(C,E,C,C,E,C))
print ('Z1, Z2, Z3:', Z1.item(), Z2.item(), Z3.item()) print (' Z1, Z2, Z3:', Z1.item(), Z2.item(), Z3.item())
lnZ += torch.log(Z1.abs()) + torch.log(Z3.abs()) - 2.*torch.log(Z2.abs()) lnZ += torch.log(Z1.abs()) + torch.log(Z3.abs()) - 2.*torch.log(Z2.abs())
return lnZ, truncation_error return lnZ, truncation_error
......
...@@ -35,9 +35,7 @@ if __name__=='__main__': ...@@ -35,9 +35,7 @@ if __name__=='__main__':
B = 0.01* torch.randn(d, D, D, D, D, dtype=dtype, device=device) B = 0.01* torch.randn(d, D, D, D, D, dtype=dtype, device=device)
#symmetrize #symmetrize
A = torch.zeros_like(B) A = (B + B.permute(0, 4, 2, 3, 1) + B.permute(0, 1, 3, 2, 4) + B.permute(0, 3, 4, 1, 2) + B.permute(0, 2, 1, 4, 3))/5.
for i in permutations([1, 2, 3, 4], 4):
A = A + B.permute([0]+list(i))/24.
A = A.view(d, D**4) A = A.view(d, D**4)
A = torch.nn.Parameter(A) A = torch.nn.Parameter(A)
......
...@@ -4,8 +4,9 @@ import torch ...@@ -4,8 +4,9 @@ import torch
torch.set_num_threads(4) torch.set_num_threads(4)
torch.manual_seed(42) torch.manual_seed(42)
from hotrg2 import hotrg as contraction #from hotrg2 import hotrg as contraction
#from trg import levin_nave_trg as contraction #from trg import levin_nave_trg as contraction
from ctmrg import ctmrg as contraction
if __name__=='__main__': if __name__=='__main__':
import time import time
...@@ -29,7 +30,11 @@ if __name__=='__main__': ...@@ -29,7 +30,11 @@ if __name__=='__main__':
Niter = args.Niter Niter = args.Niter
beta = torch.tensor([args.beta], dtype=dtype, device=device).requires_grad_() beta = torch.tensor([args.beta], dtype=dtype, device=device).requires_grad_()
A = torch.nn.Parameter(0.01* torch.randn(d, D**4, dtype=dtype, device=device, requires_grad=True)) B = 0.01* torch.randn(d, D, D, D, D, dtype=dtype, device=device)
#symmetrize
A = (B + B.permute(0, 4, 2, 3, 1) + B.permute(0, 1, 3, 2, 4) + B.permute(0, 3, 4, 1, 2) + B.permute(0, 2, 1, 4, 3))/5.
A = A.view(d, D**4)
A = torch.nn.Parameter(A)
#3D Ising #3D Ising
c = torch.sqrt(torch.cosh(beta)) c = torch.sqrt(torch.cosh(beta))
...@@ -38,9 +43,10 @@ if __name__=='__main__': ...@@ -38,9 +43,10 @@ if __name__=='__main__':
T = torch.einsum('ai,aj,ak,al,am,an->ijklmn', (M, M, M, M, M, M)) T = torch.einsum('ai,aj,ak,al,am,an->ijklmn', (M, M, M, M, M, M))
T = T.view(d, d**4, d) T = T.view(d, d**4, d)
optimizer = torch.optim.LBFGS([A], max_iter=50) optimizer = torch.optim.LBFGS([A], max_iter=20)
def closure(): def closure():
if beta.grad is not None: beta.grad.zero_()
optimizer.zero_grad() optimizer.zero_grad()
T1 = torch.einsum('xa,xby,yc' , (A,T,A)).view(D,D,D,D, d,d,d,d, D,D,D,D).permute(0,4,8, 1,5,9, 2,6,10, 3,7,11).contiguous().view(D**2*d, D**2*d, D**2*d, D**2*d) T1 = torch.einsum('xa,xby,yc' , (A,T,A)).view(D,D,D,D, d,d,d,d, D,D,D,D).permute(0,4,8, 1,5,9, 2,6,10, 3,7,11).contiguous().view(D**2*d, D**2*d, D**2*d, D**2*d)
...@@ -50,15 +56,16 @@ if __name__=='__main__': ...@@ -50,15 +56,16 @@ if __name__=='__main__':
t0=time.time() t0=time.time()
lnT1, error1 = contraction(T1, D**2*d, Dcut, Niter) lnT1, error1 = contraction(T1, D**2*d, Dcut, Niter)
lnT2, error2 = contraction(T2, D**2, Dcut, Niter) lnT2, error2 = contraction(T2, D**2, Dcut, Niter)
loss = (-lnT1 + lnT2)/2**Niter # loss = -lnZ of Ising loss = (-lnT1 + lnT2) # loss = -lnZ of Ising
print ('contraction done {:.3f}s'.format(time.time()-t0)) print (' contraction done {:.3f}s'.format(time.time()-t0))
print ('truncation error', error1.item(), error2.item()) print (' loss, error', loss.item(), error1.item(), error2.item())
t0=time.time() t0=time.time()
loss.backward(retain_graph=True) loss.backward(retain_graph=True)
print ('backward done {:.3f}s'.format(time.time()-t0)) print (' backward done {:.3f}s'.format(time.time()-t0))
print ('free energy, energy', loss.item(), beta.grad.item()) # En = -d lnZ / d beta
beta.grad.zero_()
return loss return loss
optimizer.step(closure) for epoch in range(100):
loss = optimizer.step(closure)
En = beta.grad.item() # En = -d lnZ / d beta
print ('epoch, free energy, energy', epoch, loss.item(), En)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment