Advanced Computing Platform for Theoretical Physics

Commit 5e342755 authored by Lei Wang's avatar Lei Wang
Browse files

not

parent fface86d
...@@ -22,7 +22,6 @@ def symmetrize(A): ...@@ -22,7 +22,6 @@ def symmetrize(A):
return As/As.norm() return As/As.norm()
if __name__=='__main__': if __name__=='__main__':
import time
import argparse import argparse
parser = argparse.ArgumentParser(description='') parser = argparse.ArgumentParser(description='')
parser.add_argument("-D", type=int, default=2, help="D") parser.add_argument("-D", type=int, default=2, help="D")
...@@ -72,28 +71,22 @@ if __name__=='__main__': ...@@ -72,28 +71,22 @@ if __name__=='__main__':
T1 = torch.einsum('xa,xby,yc' , (As,T,As)).view(D,D,D,D, d,d,d,d, D,D,D,D).permute(0,4,8, 1,5,9, 2,6,10, 3,7,11).contiguous().view(D**2*d, D**2*d, D**2*d, D**2*d) T1 = torch.einsum('xa,xby,yc' , (As,T,As)).view(D,D,D,D, d,d,d,d, D,D,D,D).permute(0,4,8, 1,5,9, 2,6,10, 3,7,11).contiguous().view(D**2*d, D**2*d, D**2*d, D**2*d)
#double layer #double layer
T2 = (As.t()@As).view(D, D, D, D, D, D, D, D).permute(0,4, 1,5, 2,6, 3,7).contiguous().view(D**2, D**2, D**2, D**2) T2 = (As.t()@As).view(D, D, D, D, D, D, D, D).permute(0,4, 1,5, 2,6, 3,7).contiguous().view(D**2, D**2, D**2, D**2)
t0=time.time()
#lnT = contraction(T1, D**2*d, Dcut, Niter, A1, lanczos_steps=args.lanczos_steps) #lnT = contraction(T1, D**2*d, Dcut, Niter, A1, lanczos_steps=args.lanczos_steps)
#lnZ = contraction(T2, D**2, Dcut, Niter, A2, lanczos_steps=args.lanczos_steps) #lnZ = contraction(T2, D**2, Dcut, Niter, A2, lanczos_steps=args.lanczos_steps)
lnT, error1 = contraction(T1, D**2*d, Dcut, Niter) lnT, error1 = contraction(T1, D**2*d, Dcut, Niter)
lnZ, error2 = contraction(T2, D**2, Dcut, Niter) lnZ, error2 = contraction(T2, D**2, Dcut, Niter)
loss = (-lnT + lnZ) loss = (-lnT + lnZ)
#print (' contraction done {:.3f}s'.format(time.time()-t0))
#print (' total loss', loss.item()) #print (' total loss', loss.item())
#print (' loss, error', loss.item(), error1.item(), error2.item()) #print (' loss, error', loss.item(), error1.item(), error2.item())
t0=time.time()
loss.backward() loss.backward()
#print (' backward done {:.3f}s'.format(time.time()-t0))
return loss return loss
for epoch in range(args.Nepochs): for epoch in range(args.Nepochs):
loss = optimizer.step(closure) loss = optimizer.step(closure)
print ('epoch, residual entropy', epoch, -loss.item()) print ('epoch, loss, gradnorm:', epoch, loss.item(), A.grad.norm().item())
#continue with Newton-CG #continue with Newton-CG
####################################################################3 ####################################################################3
def fun(x, info): def fun(x, info):
...@@ -102,7 +95,6 @@ if __name__=='__main__': ...@@ -102,7 +95,6 @@ if __name__=='__main__':
T1 = torch.einsum('xa,xby,yc' , (As,T,As)).view(D,D,D,D, d,d,d,d, D,D,D,D).permute(0,4,8, 1,5,9, 2,6,10, 3,7,11).contiguous().view(D**2*d, D**2*d, D**2*d, D**2*d) T1 = torch.einsum('xa,xby,yc' , (As,T,As)).view(D,D,D,D, d,d,d,d, D,D,D,D).permute(0,4,8, 1,5,9, 2,6,10, 3,7,11).contiguous().view(D**2*d, D**2*d, D**2*d, D**2*d)
T2 = (As.t()@As).view(D, D, D, D, D, D, D, D).permute(0,4, 1,5, 2,6, 3,7).contiguous().view(D**2, D**2, D**2, D**2) T2 = (As.t()@As).view(D, D, D, D, D, D, D, D).permute(0,4, 1,5, 2,6, 3,7).contiguous().view(D**2, D**2, D**2, D**2)
t0=time.time()
lnT, error1 = contraction(T1, D**2*d, Dcut, Niter) lnT, error1 = contraction(T1, D**2*d, Dcut, Niter)
lnZ, error2 = contraction(T2, D**2, Dcut, Niter) lnZ, error2 = contraction(T2, D**2, Dcut, Niter)
loss = (-lnT + lnZ) loss = (-lnT + lnZ)
...@@ -121,16 +113,16 @@ if __name__=='__main__': ...@@ -121,16 +113,16 @@ if __name__=='__main__':
A = info['A'] A = info['A']
if not ((x == A.detach().numpy().ravel()).all()): if ((x == A.detach().numpy().ravel()).all()):
grad = A.grad.view(-1) grad = A.grad.view(-1)
else: else:
print ('recalculate forward and grad')
A = torch.as_tensor(A).requires_grad_() A = torch.as_tensor(A).requires_grad_()
As = symmetrize(A) As = symmetrize(A)
T1 = torch.einsum('xa,xby,yc' , (As,T,As)).view(D,D,D,D, d,d,d,d, D,D,D,D).permute(0,4,8, 1,5,9, 2,6,10, 3,7,11).contiguous().view(D**2*d, D**2*d, D**2*d, D**2*d) T1 = torch.einsum('xa,xby,yc' , (As,T,As)).view(D,D,D,D, d,d,d,d, D,D,D,D).permute(0,4,8, 1,5,9, 2,6,10, 3,7,11).contiguous().view(D**2*d, D**2*d, D**2*d, D**2*d)
T2 = (As.t()@As).view(D, D, D, D, D, D, D, D).permute(0,4, 1,5, 2,6, 3,7).contiguous().view(D**2, D**2, D**2, D**2) T2 = (As.t()@As).view(D, D, D, D, D, D, D, D).permute(0,4, 1,5, 2,6, 3,7).contiguous().view(D**2, D**2, D**2, D**2)
t0=time.time()
lnT, error1 = contraction(T1, D**2*d, Dcut, Niter) lnT, error1 = contraction(T1, D**2*d, Dcut, Niter)
lnZ, error2 = contraction(T2, D**2, Dcut, Niter) lnZ, error2 = contraction(T2, D**2, Dcut, Niter)
loss = (-lnT + lnZ) loss = (-lnT + lnZ)
...@@ -141,7 +133,7 @@ if __name__=='__main__': ...@@ -141,7 +133,7 @@ if __name__=='__main__':
#dot it with the given vector #dot it with the given vector
loss = torch.dot(grad, torch.as_tensor(p)) loss = torch.dot(grad, torch.as_tensor(p))
hvp = torch.autograd.grad(loss, A)[0].view(-1) hvp = torch.autograd.grad(loss, A, retain_graph=True)[0].view(-1)
return hvp.numpy().ravel() return hvp.numpy().ravel()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment