Advanced Computing Platform for Theoretical Physics

Commit 2bf5025b authored by Lei Wang's avatar Lei Wang
Browse files

clean up

parent 143e0e03
...@@ -8,9 +8,6 @@ import torch ...@@ -8,9 +8,6 @@ import torch
torch.set_num_threads(4) torch.set_num_threads(4)
torch.manual_seed(42) torch.manual_seed(42)
#from hotrg2 import hotrg as contraction
#from trg import levin_nave_trg as contraction
#from ctmrg import ctmrg as contraction
from vmps import vmps as contraction from vmps import vmps as contraction
if __name__=='__main__': if __name__=='__main__':
...@@ -19,7 +16,7 @@ if __name__=='__main__': ...@@ -19,7 +16,7 @@ if __name__=='__main__':
parser = argparse.ArgumentParser(description='') parser = argparse.ArgumentParser(description='')
parser.add_argument("-D", type=int, default=2, help="D") parser.add_argument("-D", type=int, default=2, help="D")
parser.add_argument("-Dcut", type=int, default=24, help="Dcut") parser.add_argument("-Dcut", type=int, default=24, help="Dcut")
parser.add_argument("-Niter", type=int, default=20, help="Niter") parser.add_argument("-Niter", type=int, default=50, help="Niter")
parser.add_argument("-float32", action='store_true', help="use float32") parser.add_argument("-float32", action='store_true', help="use float32")
parser.add_argument("-lanczos", action='store_true', help="lanczos") parser.add_argument("-lanczos", action='store_true', help="lanczos")
...@@ -36,7 +33,7 @@ if __name__=='__main__': ...@@ -36,7 +33,7 @@ if __name__=='__main__':
Dcut = args.Dcut Dcut = args.Dcut
Niter = args.Niter Niter = args.Niter
B = 0.1* torch.randn(d, D, D, D, D, dtype=dtype, device=device) B = 0.01* torch.randn(d, D, D, D, D, dtype=dtype, device=device)
#symmetrize initial boundary PEPS #symmetrize initial boundary PEPS
B = (B + B.permute(0, 4, 2, 3, 1))/2. B = (B + B.permute(0, 4, 2, 3, 1))/2.
B = (B + B.permute(0, 1, 3, 2, 4))/2. B = (B + B.permute(0, 1, 3, 2, 4))/2.
...@@ -45,8 +42,8 @@ if __name__=='__main__': ...@@ -45,8 +42,8 @@ if __name__=='__main__':
A = torch.nn.Parameter(B.view(d, D**4)) A = torch.nn.Parameter(B.view(d, D**4))
#boundary MPS #boundary MPS
A1 = torch.nn.Parameter(0.1*torch.randn(Dcut, D**2*d, Dcut, dtype=dtype, device=device)) A1 = torch.nn.Parameter(0.01*torch.randn(Dcut, D**2*d, Dcut, dtype=dtype, device=device))
A2 = torch.nn.Parameter(0.1*torch.randn(Dcut, D**2, Dcut, dtype=dtype, device=device)) A2 = torch.nn.Parameter(0.01*torch.randn(Dcut, D**2, Dcut, dtype=dtype, device=device))
#dimer covering #dimer covering
T = torch.zeros(d, d, d, d, d, d, dtype=dtype, device=device) T = torch.zeros(d, d, d, d, d, d, dtype=dtype, device=device)
......
...@@ -28,18 +28,6 @@ def mpsrg(A, T, use_lanczos=False): ...@@ -28,18 +28,6 @@ def mpsrg(A, T, use_lanczos=False):
w, _ = torch.symeig(C, eigenvectors=True) w, _ = torch.symeig(C, eigenvectors=True)
lnZ2 = torch.log(w.abs().max()) lnZ2 = torch.log(w.abs().max())
#lnZ1 = 0.0
#lnZ2 = 0.0
#for i in range(no_iter):
# s = B.norm(1)
# lnZ1 = lnZ1 + torch.log(s)/2**i
# B = B/s
# B = torch.mm(B, B)
# s = C.norm(1)
# lnZ2 = lnZ2 + torch.log(s)/2**i
# C = C/s
# C = torch.mm(C, C)
return -lnZ1 + lnZ2 return -lnZ1 + lnZ2
def vmps(T, d, D, Nepochs=50, Ainit=None, use_lanczos=False): def vmps(T, d, D, Nepochs=50, Ainit=None, use_lanczos=False):
...@@ -60,11 +48,11 @@ def vmps(T, d, D, Nepochs=50, Ainit=None, use_lanczos=False): ...@@ -60,11 +48,11 @@ def vmps(T, d, D, Nepochs=50, Ainit=None, use_lanczos=False):
#print ((B-B.t()).abs().sum(), (C-C.t()).abs().sum()) #print ((B-B.t()).abs().sum(), (C-C.t()).abs().sum())
t0 = time.time() t0 = time.time()
loss = mpsrg(A, T.detach(), use_lanczos) # loss = -lnZ , here we optimize over A loss = mpsrg(A, T.detach(), use_lanczos) # loss = -lnZ , here we optimize over A
print ('mpsrg', time.time()- t0) #print ('mpsrg', time.time()- t0)
print (' loss', loss.item()) #print (' loss', loss.item())
t0 = time.time() t0 = time.time()
loss.backward(retain_graph=False) loss.backward(retain_graph=False)
print ('backward', time.time()- t0) #print ('backward', time.time()- t0)
return loss return loss
for epoch in range(Nepochs): for epoch in range(Nepochs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment