Advanced Computing Platform for Theoretical Physics

Committing large files makes the server unstable. Please commit only code and avoid committing large files.

Commit a813c8cb authored by Lei Wang

use tensornets; enable checkpoint

parent e3f4b90c
import torch
from tensornets import CTMRG
def ctmrg(T, chi, maxiter, epsilon, use_checkpoint=False):
C, E = CTMRG(T, chi, maxiter, epsilon, use_checkpoint=use_checkpoint)
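# free energy per site from the converged environment: Z1 contracts the 3x3 patch
# (four corners C, four edges E, one bulk T), Z2 the 2x3 patch without T, and Z3 the
# 2x2 patch of corners alone, so lnZ = ln|Z1| + ln|Z3| - 2 ln|Z2| cancels the
# boundary contributions and leaves the per-site contribution of T.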
Z1 = torch.einsum('ab,bcd,fd,gha,chij,fjk,lg,mil,mk', (C,E,C,E,T,E,C,E,C))
Z3 = torch.einsum('ab,bc,cd,da', (C,C,C,C))
Z2 = torch.einsum('ab,bcd,de,fa,gcf,ge',(C,E,C,C,E,C))
lnZ = torch.log(Z1.abs()) + torch.log(Z3.abs()) - 2.*torch.log(Z2.abs())
return lnZ
def symmetrize(A):
'''
A(phy, up, left, down, right)
symmetrize under left-right, up-down, skew-diagonal, and diagonal reflections
'''
Asymm = (A + A.permute(0, 1, 4, 3, 2))/2. # left-right symmetry
Asymm = (Asymm + Asymm.permute(0, 3, 2, 1, 4))/2. # up-down symmetry
Asymm = (Asymm + Asymm.permute(0, 4, 3, 2, 1))/2. # skew-diagonal symmetry
Asymm = (Asymm + Asymm.permute(0, 2, 1, 4, 3))/2. # diagonal symmetry
return Asymm/Asymm.norm()
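# Quick sanity sketch (illustration only, not part of this commit): the output of
# symmetrize should be invariant under each reflection used above and have unit norm.
# The dimensions d=2, D=3 below are arbitrary choices for the check.
if __name__ == '__main__':
    B0 = torch.randn(2, 3, 3, 3, 3, dtype=torch.float64)
    Bs = symmetrize(B0)
    assert torch.allclose(Bs, Bs.permute(0, 1, 4, 3, 2))  # left-right
    assert torch.allclose(Bs, Bs.permute(0, 3, 2, 1, 4))  # up-down
    assert torch.allclose(Bs, Bs.permute(0, 4, 3, 2, 1))  # skew-diagonal
    assert torch.allclose(Bs, Bs.permute(0, 2, 1, 4, 3))  # diagonal
    assert abs(Bs.norm().item() - 1.0) < 1e-12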
@@ -8,50 +8,31 @@ import torch
torch.set_num_threads(1)
torch.manual_seed(42)
#from trg import levin_nave_trg as contraction
from ctmrg import ctmrg as contraction
#from vmps import vmps as contraction
def symmetrize(A):
As = A.view(d, D, D, D, D)
As = (As + As.permute(0, 4, 2, 3, 1))/2.
As = (As + As.permute(0, 1, 3, 2, 4))/2.
As = (As + As.permute(0, 3, 4, 1, 2))/2.
As = (As + As.permute(0, 2, 1, 4, 3))/2.
As = As.view(d, D**4)
return As/As.norm()
from contraction import symmetrize, ctmrg
if __name__=='__main__':
import argparse
parser = argparse.ArgumentParser(description='')
parser.add_argument("-D", type=int, default=2, help="D")
parser.add_argument("-Dcut", type=int, default=30, help="Dcut")
parser.add_argument("-Niter", type=int, default=50, help="Niter")
parser.add_argument("-Nepochs", type=int, default=10, help="Nepochs")
parser.add_argument("-chi", type=int, default=30, help="chi")
parser.add_argument("-maxiter", type=int, default=50, help="maxiter")
parser.add_argument("-epsilon", type=float, default=1E-10, help="maxiter")
parser.add_argument("-nepochs", type=int, default=10, help="nepochs")
parser.add_argument("-float32", action='store_true', help="use float32")
parser.add_argument("-lanczos_steps", type=int, default=0, help="lanczos steps")
parser.add_argument("-use_checkpoint", action='store_true', help="use checkpoint")
parser.add_argument("-cuda", type=int, default=-1, help="use GPU")
args = parser.parse_args()
device = torch.device("cpu" if args.cuda<0 else "cuda:"+str(args.cuda))
dtype = torch.float32 if args.float32 else torch.float64
if args.lanczos_steps>0: print ('lanczos steps', args.lanczos_steps)
d = 2 # fixed
D = args.D
Dcut = args.Dcut
Niter = args.Niter
B = 0.01* torch.randn(d, D, D, D, D, dtype=dtype, device=device)
#symmetrize initial boundary PEPS
A = torch.nn.Parameter(symmetrize(B).view(d, D**4))
A = torch.nn.Parameter(symmetrize(B))
#boundary MPS
#A1 = torch.nn.Parameter(0.01*torch.randn(Dcut, D**2*d, Dcut, dtype=dtype, device=device))
#A2 = torch.nn.Parameter(0.01*torch.randn(Dcut, D**2, Dcut, dtype=dtype, device=device))
#dimer covering
T = torch.zeros(d, d, d, d, d, d, dtype=dtype, device=device)
T[0, 0, 0, 0, 0, 1] = 1.0
@@ -66,23 +47,20 @@ if __name__=='__main__':
def closure():
optimizer.zero_grad()
As = symmetrize(A)
As = symmetrize(A).view(d, D**4)
T1 = torch.einsum('xa,xby,yc' , (As,T,As)).view(D,D,D,D, d,d,d,d, D,D,D,D).permute(0,4,8, 1,5,9, 2,6,10, 3,7,11).contiguous().view(D**2*d, D**2*d, D**2*d, D**2*d)
#double layer
T2 = (As.t()@As).view(D, D, D, D, D, D, D, D).permute(0,4, 1,5, 2,6, 3,7).contiguous().view(D**2, D**2, D**2, D**2)
#lnT = contraction(T1, D**2*d, Dcut, Niter, A1, lanczos_steps=args.lanczos_steps)
#lnZ = contraction(T2, D**2, Dcut, Niter, A2, lanczos_steps=args.lanczos_steps)
lnT, error1 = contraction(T1, D**2*d, Dcut, Niter)
lnZ, error2 = contraction(T2, D**2, Dcut, Niter)
lnT = ctmrg(T1, args.chi, args.maxiter, args.epsilon, use_checkpoint=args.use_checkpoint)
lnZ = ctmrg(T2, args.chi, args.maxiter, args.epsilon, use_checkpoint=args.use_checkpoint)
loss = (-lnT + lnZ)
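#loss = -(lnT - lnZ) is the negative per-site log of the ratio <A|T|A>/<A|A>;
#minimizing it maximizes the per-site contraction of T captured by the boundary PEPS A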
#print (' total loss', loss.item())
#print (' loss, error', loss.item(), error1.item(), error2.item())
loss.backward()
return loss
for epoch in range(args.Nepochs):
for epoch in range(args.nepochs):
loss = optimizer.step(closure)
print ('epoch, loss, gradnorm:', epoch, loss.item(), A.grad.norm().item())
@@ -90,13 +68,14 @@ if __name__=='__main__':
####################################################################3
def fun(x, info):
A = torch.as_tensor(x, device=device).requires_grad_()
As = symmetrize(A)
A = torch.as_tensor(x, device=device).view(d, D, D, D, D).requires_grad_()
As = symmetrize(A).view(d, D**4)
T1 = torch.einsum('xa,xby,yc' , (As,T,As)).view(D,D,D,D, d,d,d,d, D,D,D,D).permute(0,4,8, 1,5,9, 2,6,10, 3,7,11).contiguous().view(D**2*d, D**2*d, D**2*d, D**2*d)
T2 = (As.t()@As).view(D, D, D, D, D, D, D, D).permute(0,4, 1,5, 2,6, 3,7).contiguous().view(D**2, D**2, D**2, D**2)
lnT, error1 = contraction(T1, D**2*d, Dcut, Niter)
lnZ, error2 = contraction(T2, D**2, Dcut, Niter)
lnT = ctmrg(T1, args.chi, args.maxiter, args.epsilon, use_checkpoint=args.use_checkpoint)
lnZ = ctmrg(T2, args.chi, args.maxiter, args.epsilon, use_checkpoint=args.use_checkpoint)
loss = (-lnT + lnZ)
loss.backward(create_graph=True)
@@ -112,30 +91,31 @@ if __name__=='__main__':
def hessp(x, p, info):
A = info['A']
if ((x == A.detach().cpu().numpy().ravel()).all()):
grad = A.grad.view(-1)
else:
if not ((x == A.detach().cpu().numpy().ravel()).all()):
print ('recalculate forward and grad')
A = torch.as_tensor(A, device=device).requires_grad_()
A = torch.as_tensor(x, device=device).view(d, D, D, D, D).requires_grad_()
As = symmetrize(A)
As = symmetrize(A).view(d, D**4)
T1 = torch.einsum('xa,xby,yc' , (As,T,As)).view(D,D,D,D, d,d,d,d, D,D,D,D).permute(0,4,8, 1,5,9, 2,6,10, 3,7,11).contiguous().view(D**2*d, D**2*d, D**2*d, D**2*d)
T2 = (As.t()@As).view(D, D, D, D, D, D, D, D).permute(0,4, 1,5, 2,6, 3,7).contiguous().view(D**2, D**2, D**2, D**2)
lnT, error1 = contraction(T1, D**2*d, Dcut, Niter)
lnZ, error2 = contraction(T2, D**2, Dcut, Niter)
lnT = ctmrg(T1, args.chi, args.maxiter, args.epsilon, use_checkpoint=args.use_checkpoint)
lnZ = ctmrg(T2, args.chi, args.maxiter, args.epsilon, use_checkpoint=args.use_checkpoint)
loss = (-lnT + lnZ)
loss.backward(create_graph=True)
grad = A.grad.view(-1)
Agrad = A.grad.data.clone() #take A.grad.data out
#dot it with the given vector
loss = torch.dot(grad, torch.as_tensor(p, device=device))
hvp = torch.autograd.grad(loss, A, retain_graph=True)[0].view(-1)
loss = A.grad.view(-1)@torch.as_tensor(p, device=device)
A.grad.data.zero_()
loss.backward(retain_graph=True)
res = A.grad.detach().cpu().numpy().ravel()
A.grad.data = Agrad #put it back
return res
return hvp.cpu().numpy().ravel()
import scipy.optimize
x0 = A.detach().cpu().numpy().ravel()
......
import torch
from .adlib import SVD
svd = SVD.apply
from .ctmrg import CTMRG
from .vumps import vumps_run, vumps_calc
def projection(T, epsilon=1E-3):
D = T.shape[0] # double layer bond dimension
M = T.view(D, -1)
M = M@M.t()
U, S, _ = svd(M)
#S = S/S.max()
#chi = (S>epsilon).sum().item()
#up to truncation error
chi = (torch.cumsum(S, dim=0)/S.sum() <= 1-epsilon).sum().item()
U = U[:, :chi].view(D, chi)
#print (S/S.max())
#print (torch.cumsum(S, dim=0)/S.sum() )
print ('---->truncated from', D, 'to', chi)
return torch.einsum('abcd,ai,bj,ck,dl->ijkl', (T, U, U, U, U)), U
def ctmrg(T, chi, maxiter, epsilon):
with torch.no_grad():
C, E = CTMRG(T, chi, maxiter, epsilon)
Z1 = torch.einsum('ab,bcd,fd,gha,chij,fjk,lg,mil,mk', (C,E,C,E,T,E,C,E,C))
Z3 = torch.einsum('ab,bc,cd,da', (C,C,C,C))
Z2 = torch.einsum('ab,bcd,de,fa,gcf,ge',(C,E,C,C,E,C))
lnZ = torch.log(Z1.abs()) + torch.log(Z3.abs()) - 2.*torch.log(Z2.abs())
return lnZ
def vumps(T, chi, maxiter, epsilon):
with torch.no_grad():
_, AC, F, C, _ = vumps_run(T, chi, epsilon, maxiter)
lnZ = torch.log(vumps_calc(T, AC, F, C))
return lnZ
from .ctmrg import CTMRG
from .singlelayer import projection, ctmrg, vumps
from .utils import save_checkpoint, load_checkpoint, kronecker_product, symmetrize
from .svd import SVD
from .eigh import EigenSolver
from .qr import QR
'''
PyTorch has its own implementation of backward function for symmetric eigensolver https://github.com/pytorch/pytorch/blob/291746f11047361100102577ce7d1cfa1833be50/tools/autograd/templates/Functions.cpp#L1660
However, it assumes a triangular adjoint.
We reimplement it to return a symmetric adjoint.
'''
import numpy as np
import torch
class EigenSolver(torch.autograd.Function):
@staticmethod
def forward(self, A):
w, v = torch.symeig(A, eigenvectors=True)
self.save_for_backward(w, v)
return w, v
@staticmethod
def backward(self, dw, dv):
w, v = self.saved_tensors
dtype, device = w.dtype, w.device
N = v.shape[0]
F = w - w[:,None]
F.diagonal().fill_(np.inf)
# safe inverse
msk = (torch.abs(F) < 1e-20)
F[msk] += 1e-20
F = 1./F
vt = v.t()
vdv = vt@dv
return v@(torch.diag(dw) + F*(vdv-vdv.t())/2) @vt
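# A minimal gradcheck sketch (illustration only, not part of this commit): since the
# backward above returns a symmetric adjoint, the input is symmetrized explicitly
# inside the checked function, and only the eigenvalues are compared.
if __name__ == '__main__':
    A0 = torch.randn(4, 4, dtype=torch.float64, requires_grad=True)
    f = lambda x: EigenSolver.apply(0.5*(x + x.t()))[0]
    print(torch.autograd.gradcheck(f, (A0,), eps=1e-6, atol=1e-4))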
import torch
from torch.utils.checkpoint import detach_variable
def step(A, x):
y = A@x
y = y[0].sign() * y
return y/y.norm()
class FixedPoint(torch.autograd.Function):
@staticmethod
def forward(ctx, A, x0, tol):
x, x_prev = step(A, x0), x0
while torch.dist(x, x_prev) > tol:
x, x_prev = step(A, x), x
ctx.save_for_backward(A, x)
ctx.tol = tol
return x
@staticmethod
def backward(ctx, grad):
A, x = detach_variable(ctx.saved_tensors)
dA = grad
while True:
with torch.enable_grad():
grad = torch.autograd.grad(step(A, x), x, grad_outputs=grad)[0]
if (torch.norm(grad) > ctx.tol):
dA = dA + grad
else:
break
with torch.enable_grad():
dA = torch.autograd.grad(step(A, x), A, grad_outputs=dA)[0]
return dA, None, None
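# Usage sketch (illustration only, not part of this commit): differentiate the dominant
# eigenvector of a positive symmetric matrix through the power-iteration fixed point.
if __name__ == '__main__':
    torch.manual_seed(0)
    N = 6
    A = torch.rand(N, N, dtype=torch.float64)
    A = 0.5*(A + A.t())         # positive symmetric matrix with a clear spectral gap
    A.requires_grad_()
    x0 = torch.rand(N, dtype=torch.float64)
    x0 = x0/x0.norm()
    x = FixedPoint.apply(A, x0, 1E-12)
    (x**2)[0].backward()        # any scalar function of the fixed point
    print('gradient norm:', A.grad.norm().item())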
import torch
class QR(torch.autograd.Function):
@staticmethod
def forward(self, A):
Q, R = torch.qr(A)
self.save_for_backward(A, Q, R)
return Q, R
@staticmethod
def backward(self, dq, dr):
A, q, r = self.saved_tensors
if r.shape[0] == r.shape[1]:
return _simple_qr_backward(q, r, dq ,dr)
M, N = r.shape
B = A[:,M:]
dU = dr[:,:M]
dD = dr[:,M:]
U = r[:,:M]
da = _simple_qr_backward(q, U, dq+B@dD.t(), dU)
db = q@dD
return torch.cat([da, db], 1)
def _simple_qr_backward(q, r, dq, dr):
if r.shape[-2] != r.shape[-1]:
raise NotImplementedError("QrGrad not implemented when ncols > nrows "
"or full_matrices is true and ncols != nrows.")
qdq = q.t() @ dq
qdq_ = qdq - qdq.t()
rdr = r @ dr.t()
rdr_ = rdr - rdr.t()
tril = torch.tril(qdq_ + rdr_)
def _TriangularSolve(x, r):
"""Equiv to x @ torch.inverse(r).t() if r is upper-tri."""
res = torch.triangular_solve(x.t(), r, upper=True, transpose=False)[0].t()
return res
grad_a = q @ (dr + _TriangularSolve(tril, r))
grad_b = _TriangularSolve(dq - q @ qdq, r)
return grad_a + grad_b
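# A minimal gradcheck sketch (illustration only, not part of this commit) for the
# custom backward, using a tall full-rank input so the square-R branch is exercised.
if __name__ == '__main__':
    A0 = torch.randn(6, 4, dtype=torch.float64, requires_grad=True)
    print(torch.autograd.gradcheck(QR.apply, (A0,), eps=1e-6, atol=1e-4))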
'''
PyTorch has its own implementation of backward function for SVD https://github.com/pytorch/pytorch/blob/291746f11047361100102577ce7d1cfa1833be50/tools/autograd/templates/Functions.cpp#L1577
We reimplement it with a safe inverse function to handle degenerate singular values.
'''
import numpy as np
import torch
def safe_inverse(x, epsilon=1E-12):
return x/(x**2 + epsilon)
class SVD(torch.autograd.Function):
@staticmethod
def forward(self, A):
U, S, V = torch.svd(A)
self.save_for_backward(U, S, V)
return U, S, V
@staticmethod
def backward(self, dU, dS, dV):
U, S, V = self.saved_tensors
Vt = V.t()
Ut = U.t()
M = U.size(0)
N = V.size(0)
NS = len(S)
F = (S - S[:, None])
F = safe_inverse(F)
F.diagonal().fill_(0)
G = (S + S[:, None])
G.diagonal().fill_(np.inf)
G = 1/G
UdU = Ut @ dU
VdV = Vt @ dV
Su = (F+G)*(UdU-UdU.t())/2
Sv = (F-G)*(VdV-VdV.t())/2
dA = U @ (Su + Sv + torch.diag(dS)) @ Vt
if (M>NS):
dA = dA + (torch.eye(M, dtype=dU.dtype, device=dU.device) - U@Ut) @ (dU/S) @ Vt
if (N>NS):
dA = dA + (U/S) @ dV.t() @ (torch.eye(N, dtype=dU.dtype, device=dU.device) - V@Vt)
return dA
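# A minimal gradcheck sketch (illustration only, not part of this commit); safe_inverse
# above guards the backward when singular values become nearly degenerate.
if __name__ == '__main__':
    M0 = torch.randn(5, 3, dtype=torch.float64, requires_grad=True)
    print(torch.autograd.gradcheck(SVD.apply, (M0,), eps=1e-6, atol=1e-4))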
import torch
from torch.utils.checkpoint import checkpoint
from .adlib import SVD
svd = SVD.apply
#from .adlib import EigenSolver
#symeig = EigenSolver.apply
def renormalize(*tensors):
# T(up,left,down,right), u=up, l=left, d=down, r=right
# C(d,r), EL(u,r,d), EU(l,d,r)
C, E, T, chi = tensors
dimT, dimE = T.shape[0], E.shape[0]
D_new = min(dimE*dimT, chi)
# step 1: construct the density matrix Rho
Rho = torch.tensordot(C,E,([1],[0])) # C(ef)*EU(fga)=Rho(ega)
Rho = torch.tensordot(Rho,E,([0],[0])) # Rho(ega)*EL(ehc)=Rho(gahc)
Rho = torch.tensordot(Rho,T,([0,2],[0,1])) # Rho(gahc)*T(ghdb)=Rho(acdb)
Rho = Rho.permute(0,3,1,2).contiguous().view(dimE*dimT, dimE*dimT) # Rho(acdb)->Rho(ab;cd)
Rho = Rho+Rho.t()
Rho = Rho/Rho.norm()
# step 2: Get Isometry P
U, S, V = torch.svd(Rho)
truncation_error = S[D_new:].sum()/S.sum()
P = U[:, :D_new] # projection operator
#can also do symeig since Rho is symmetric
#S, U = symeig(Rho)
#sorted, indices = torch.sort(S.abs(), descending=True)
#truncation_error = sorted[D_new:].sum()/sorted.sum()
#S = S[indices][:D_new]
#P = U[:, indices][:, :D_new] # projection operator
# step 3: renormalize C and E
C = (P.t() @ Rho @ P) #C(D_new, D_new)
## EL(u,r,d)
P = P.view(dimE,dimT,D_new)
E = torch.tensordot(E, P, ([0],[0])) # EL(def)P(dga)=E(efga)
E = torch.tensordot(E, T, ([0,2],[1,0])) # E(efga)T(gehb)=E(fahb)
E = torch.tensordot(E, P, ([0,2],[0,1])) # E(fahb)P(fhc)=E(abc)
# step 4: symmetrize C and E
C = 0.5*(C+C.t())
E = 0.5*(E + E.permute(2, 1, 0))
return C/C.norm(), E, S.abs()/S.abs().max(), truncation_error
def CTMRG(T, chi, max_iter, epsilon, use_checkpoint=False):
# T(up, left, down, right)
# C(down, right), E(up,right,down)
C = T.sum((0,1)) #
E = T.sum(1).permute(0,2,1)
truncation_error = 0.0
sold = torch.zeros(chi, dtype=T.dtype, device=T.device)
diff = 1E1
for n in range(max_iter):
tensors = C, E, T, torch.tensor(chi)
if use_checkpoint: # use checkpoint to save memory
C, E, s, error = checkpoint(renormalize, *tensors)
else:
C, E, s, error = renormalize(*tensors)
Enorm = E.norm()
E = E/Enorm
truncation_error += error.item()
if (s.numel() == sold.numel()):
diff = (s-sold).norm().item()
#print( s, sold )
#print( 'n: %d, Enorm: %g, error: %e, diff: %e' % (n, Enorm, error.item(), diff) )
if (diff < epsilon):
break
sold = s
#print ('---->ctmrg converged at iterations %d to %.5e, truncation error: %.5f'%(n, diff, truncation_error/n))
return C, E
if __name__=='__main__':
import time
torch.manual_seed(42)
D = 6
chi = 80
max_iter = 100
epsilon = 1E-10  # convergence threshold for the CTMRG singular-value spectrum
device = 'cpu'
# T(u,l,d,r)
T = torch.randn(D, D, D, D, dtype=torch.float64, device=device, requires_grad=True)
T = (T + T.permute(0, 3, 2, 1))/2. # left-right symmetry
T = (T + T.permute(2, 1, 0, 3))/2. # up-down symmetry
T = (T + T.permute(3, 2, 1, 0))/2. # skew-diagonal symmetry
T = (T + T.permute(1, 0, 3, 2))/2. # diagonal symmetry
T = T/T.norm()
C, E = CTMRG(T, chi, max_iter, epsilon, use_checkpoint=True)
C, E = CTMRG(T, chi, max_iter, epsilon, use_checkpoint=False)
print( 'diffC = ', torch.dist( C, C.t() ) )
print( 'diffE = ', torch.dist( E, E.permute(2,1,0) ) )
import torch
from scipy.sparse.linalg import eigs
from scipy.sparse.linalg import LinearOperator
def dominant_eigensolver(Hopt, v0, tol):
N = v0.shape[0]
A = LinearOperator((N, N), matvec=lambda x: Hopt(torch.as_tensor(x)).detach().numpy())
vals, vecs = eigs(A, k=1, which='LM', v0=v0.detach().numpy(), tol=tol)
w = torch.as_tensor(vals[0].real, dtype=v0.dtype, device=v0.device)
v = torch.as_tensor(vecs[:, 0].real, dtype=v0.dtype, device=v0.device)
return w, v
import torch
def fixgauge_qr(A):
Q, R = torch.qr(A)
Rsgn = torch.diag(R).sign()
Q = Q * Rsgn
R = Rsgn[:, None] * R
return Q, R
def test_FixQR():
M, N = 10, 4
# torch.manual_seed(2)
A = torch.randn(M, N, dtype=torch.float64)
Q0, R0 = torch.qr(A)
Q1, R1 = fixgauge_qr(A)
print('R0 = ', R0)
print('R1 = ', R1)
print('qr0 = ', torch.dist(Q0 @ R0, A))
print('qr1 = ', torch.dist(Q1 @ R1, A))
if __name__ == "__main__":
test_FixQR()
import torch
from .fixgaugeqr import fixgauge_qr
from .dominanteigensolver import dominant_eigensolver
def left_orthonormalize(A, C, epsilon, maxiter=100):
'''
C*A = lambda*AL*C
'''
chi, d = A.shape[0], A.shape[1]
A = A.contiguous().view(chi, d * chi)
def step(C):
_, C = fixgauge_qr(C)
C = C / C.norm()
AC = (C @ A).view(chi * d, chi)
AL, Cnew = fixgauge_qr(AC)
lamb = Cnew.norm()
Cnew = Cnew / lamb
diff = torch.dist(Cnew, C)
return AL.view(chi, d, chi), Cnew, lamb, diff
AL, Cnew, lamb, diff = step(C)
for it in range(maxiter):