Advanced Computing Platform for Theoretical Physics

Commit 1afdb8f6 authored by Lei Wang's avatar Lei Wang
Browse files

symmetrize T in vmps

parent a144be5b
...@@ -62,8 +62,8 @@ if __name__=='__main__': ...@@ -62,8 +62,8 @@ if __name__=='__main__':
T2 = (A.t()@A).view(D, D, D, D, D, D, D, D).permute(0,4, 1,5, 2,6, 3,7).contiguous().view(D**2, D**2, D**2, D**2) T2 = (A.t()@A).view(D, D, D, D, D, D, D, D).permute(0,4, 1,5, 2,6, 3,7).contiguous().view(D**2, D**2, D**2, D**2)
t0=time.time() t0=time.time()
lnT, error1 = contraction(T1, D**2*d, Dcut, Niter) lnT, _ = contraction(T1, D**2*d, Dcut, Niter)
lnZ, error2 = contraction(T2, D**2, Dcut, Niter) lnZ, _ = contraction(T2, D**2, Dcut, Niter)
loss = (-lnT + lnZ) loss = (-lnT + lnZ)
print (' contraction done {:.3f}s'.format(time.time()-t0)) print (' contraction done {:.3f}s'.format(time.time()-t0))
print (' total loss', loss.item()) print (' total loss', loss.item())
......
...@@ -28,11 +28,15 @@ def mpsrg(A, T): ...@@ -28,11 +28,15 @@ def mpsrg(A, T):
# C = torch.mm(C, C) # C = torch.mm(C, C)
return -lnZ1 + lnZ2 return -lnZ1 + lnZ2
def vmps(T, d, D, no_iter, Nepochs=5): def vmps(T, d, D, no_iter, Nepochs=50):
#symmetrize
T = (T + T.permute(3, 1, 2, 0))/2. #left-right
T = (T + T.permute(0, 2, 1, 3))/2. #up-down
A = torch.nn.Parameter(0.01*torch.randn(D, d, D, dtype=T.dtype, device=T.device)) A = torch.nn.Parameter(0.01*torch.randn(D, d, D, dtype=T.dtype, device=T.device))
optimizer = torch.optim.LBFGS([A], max_iter=10) optimizer = torch.optim.LBFGS([A], max_iter=20)
def closure(): def closure():
optimizer.zero_grad() optimizer.zero_grad()
#print ('einsum', time.time()- t0) #print ('einsum', time.time()- t0)
...@@ -40,7 +44,7 @@ def vmps(T, d, D, no_iter, Nepochs=5): ...@@ -40,7 +44,7 @@ def vmps(T, d, D, no_iter, Nepochs=5):
#t0 = time.time() #t0 = time.time()
loss = mpsrg(A, T.detach()) # loss = -lnZ , here we optimize over A loss = mpsrg(A, T.detach()) # loss = -lnZ , here we optimize over A
#print ('mpsrg', time.time()- t0) #print ('mpsrg', time.time()- t0)
print (' loss', loss.item()) #print (' loss', loss.item())
#t0 = time.time() #t0 = time.time()
loss.backward(retain_graph=False) loss.backward(retain_graph=False)
#print ('backward', time.time()- t0) #print ('backward', time.time()- t0)
...@@ -48,7 +52,7 @@ def vmps(T, d, D, no_iter, Nepochs=5): ...@@ -48,7 +52,7 @@ def vmps(T, d, D, no_iter, Nepochs=5):
for epoch in range(Nepochs): for epoch in range(Nepochs):
loss = optimizer.step(closure) loss = optimizer.step(closure)
print (' epoch, free energy', epoch, loss.item()) #print (' epoch, free energy', epoch, loss.item())
return -mpsrg(A.detach(), T), None # pass lnZ out, we need to optimize over T return -mpsrg(A.detach(), T), None # pass lnZ out, we need to optimize over T
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment