Commit a813c8cb authored by Lei Wang

use tensornets; enable checkpoint

parent e3f4b90c
import torch
from tensornets import CTMRG
def ctmrg(T, chi, maxiter, epsilon, use_checkpoint=False):
C, E = CTMRG(T, chi, maxiter, epsilon, use_checkpoint=use_checkpoint)
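# Free energy per site from the converged environment (standard corner-transfer-matrix estimator):
# Z1 contracts the 3x3 patch (4 corners, 4 edges, one bulk T), Z3 the 4 corners alone,
# and Z2 the 2x3 patch with two edges, so lnZ = ln Z1 + ln Z3 - 2 ln Z2.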
Z1 = torch.einsum('ab,bcd,fd,gha,chij,fjk,lg,mil,mk', (C,E,C,E,T,E,C,E,C))
Z3 = torch.einsum('ab,bc,cd,da', (C,C,C,C))
Z2 = torch.einsum('ab,bcd,de,fa,gcf,ge',(C,E,C,C,E,C))
lnZ = torch.log(Z1.abs()) + torch.log(Z3.abs()) - 2.*torch.log(Z2.abs())
return lnZ
def symmetrize(A):
'''
A(phy, up, left, down, right)
left-right, up-down, diagonal symmetrize
'''
Asymm = (A + A.permute(0, 1, 4, 3, 2))/2. # left-right symmetry
Asymm = (Asymm + Asymm.permute(0, 3, 2, 1, 4))/2. # up-down symmetry
Asymm = (Asymm + Asymm.permute(0, 4, 3, 2, 1))/2. # skew-diagonal symmetry
Asymm = (Asymm + Asymm.permute(0, 2, 1, 4, 3))/2. # diagonal symmetry
return Asymm/Asymm.norm()
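# A minimal check of the symmetrization (not part of the original file; shapes are illustrative):
# the returned tensor should be invariant under each of the four lattice symmetries above.
if __name__ == '__main__':
    A = torch.randn(2, 3, 3, 3, 3, dtype=torch.float64)
    As = symmetrize(A)
    for perm in [(0, 1, 4, 3, 2), (0, 3, 2, 1, 4), (0, 4, 3, 2, 1), (0, 2, 1, 4, 3)]:
        print(perm, torch.dist(As, As.permute(*perm)).item())  # all distances should be ~0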
@@ -8,49 +8,30 @@ import torch
torch.set_num_threads(1)
torch.manual_seed(42)
#from trg import levin_nave_trg as contraction
from ctmrg import ctmrg as contraction
#from vmps import vmps as contraction
def symmetrize(A):
As = A.view(d, D, D, D, D)
As = (As + As.permute(0, 4, 2, 3, 1))/2.
As = (As + As.permute(0, 1, 3, 2, 4))/2.
As = (As + As.permute(0, 3, 4, 1, 2))/2.
As = (As + As.permute(0, 2, 1, 4, 3))/2.
As = As.view(d, D**4)
return As/As.norm()
from contraction import symmetrize, ctmrg
if __name__=='__main__':
import argparse
parser = argparse.ArgumentParser(description='')
parser.add_argument("-D", type=int, default=2, help="D")
parser.add_argument("-Dcut", type=int, default=30, help="Dcut")
parser.add_argument("-Niter", type=int, default=50, help="Niter")
parser.add_argument("-Nepochs", type=int, default=10, help="Nepochs")
parser.add_argument("-chi", type=int, default=30, help="chi")
parser.add_argument("-maxiter", type=int, default=50, help="maxiter")
parser.add_argument("-epsilon", type=float, default=1E-10, help="maxiter")
parser.add_argument("-nepochs", type=int, default=10, help="nepochs")
parser.add_argument("-float32", action='store_true', help="use float32")
parser.add_argument("-lanczos_steps", type=int, default=0, help="lanczos steps")
parser.add_argument("-use_checkpoint", action='store_true', help="use checkpoint")
parser.add_argument("-cuda", type=int, default=-1, help="use GPU")
args = parser.parse_args()
device = torch.device("cpu" if args.cuda<0 else "cuda:"+str(args.cuda))
dtype = torch.float32 if args.float32 else torch.float64
if args.lanczos_steps>0: print ('lanczos steps', args.lanczos_steps)
d = 2 # fixed
D = args.D
Dcut = args.Dcut
Niter = args.Niter
B = 0.01* torch.randn(d, D, D, D, D, dtype=dtype, device=device)
#symmetrize initial boundary PEPS
A = torch.nn.Parameter(symmetrize(B).view(d, D**4))
#boundary MPS
#A1 = torch.nn.Parameter(0.01*torch.randn(Dcut, D**2*d, Dcut, dtype=dtype, device=device))
#A2 = torch.nn.Parameter(0.01*torch.randn(Dcut, D**2, Dcut, dtype=dtype, device=device))
A = torch.nn.Parameter(symmetrize(B))
#dimer covering
T = torch.zeros(d, d, d, d, d, d, dtype=dtype, device=device)
@@ -66,23 +47,20 @@ if __name__=='__main__':
def closure():
optimizer.zero_grad()
As = symmetrize(A)
As = symmetrize(A).view(d, D**4)
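# T1: double-layer tensor with the dimer constraint tensor T sandwiched between the two PEPS layers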
T1 = torch.einsum('xa,xby,yc' , (As,T,As)).view(D,D,D,D, d,d,d,d, D,D,D,D).permute(0,4,8, 1,5,9, 2,6,10, 3,7,11).contiguous().view(D**2*d, D**2*d, D**2*d, D**2*d)
#double layer
T2 = (As.t()@As).view(D, D, D, D, D, D, D, D).permute(0,4, 1,5, 2,6, 3,7).contiguous().view(D**2, D**2, D**2, D**2)
#lnT = contraction(T1, D**2*d, Dcut, Niter, A1, lanczos_steps=args.lanczos_steps)
#lnZ = contraction(T2, D**2, Dcut, Niter, A2, lanczos_steps=args.lanczos_steps)
lnT, error1 = contraction(T1, D**2*d, Dcut, Niter)
lnZ, error2 = contraction(T2, D**2, Dcut, Niter)
lnT = ctmrg(T1, args.chi, args.maxiter, args.epsilon, use_checkpoint=args.use_checkpoint)
lnZ = ctmrg(T2, args.chi, args.maxiter, args.epsilon, use_checkpoint=args.use_checkpoint)
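# objective: maximize lnT - lnZ, the per-site log ratio of the dimer-constrained network to the PEPS norm network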
loss = (-lnT + lnZ)
#print (' total loss', loss.item())
#print (' loss, error', loss.item(), error1.item(), error2.item())
loss.backward()
return loss
for epoch in range(args.Nepochs):
for epoch in range(args.nepochs):
loss = optimizer.step(closure)
print ('epoch, loss, gradnorm:', epoch, loss.item(), A.grad.norm().item())
@@ -90,13 +68,14 @@ if __name__=='__main__':
####################################################################
def fun(x, info):
A = torch.as_tensor(x, device=device).requires_grad_()
As = symmetrize(A)
A = torch.as_tensor(x, device=device).view(d, D, D, D, D).requires_grad_()
As = symmetrize(A).view(d, D**4)
T1 = torch.einsum('xa,xby,yc' , (As,T,As)).view(D,D,D,D, d,d,d,d, D,D,D,D).permute(0,4,8, 1,5,9, 2,6,10, 3,7,11).contiguous().view(D**2*d, D**2*d, D**2*d, D**2*d)
T2 = (As.t()@As).view(D, D, D, D, D, D, D, D).permute(0,4, 1,5, 2,6, 3,7).contiguous().view(D**2, D**2, D**2, D**2)
lnT, error1 = contraction(T1, D**2*d, Dcut, Niter)
lnZ, error2 = contraction(T2, D**2, Dcut, Niter)
lnT = ctmrg(T1, args.chi, args.maxiter, args.epsilon, use_checkpoint=args.use_checkpoint)
lnZ = ctmrg(T2, args.chi, args.maxiter, args.epsilon, use_checkpoint=args.use_checkpoint)
loss = (-lnT + lnZ)
loss.backward(create_graph=True)
@@ -112,30 +91,31 @@ if __name__=='__main__':
def hessp(x, p, info):
A = info['A']
if ((x == A.detach().cpu().numpy().ravel()).all()):
grad = A.grad.view(-1)
else:
if not ((x == A.detach().cpu().numpy().ravel()).all()):
print ('recalculate forward and grad')
A = torch.as_tensor(A, device=device).requires_grad_()
A = torch.as_tensor(x, device=device).view(d, D, D, D, D).requires_grad_()
As = symmetrize(A)
As = symmetrize(A).view(d, D**4)
T1 = torch.einsum('xa,xby,yc' , (As,T,As)).view(D,D,D,D, d,d,d,d, D,D,D,D).permute(0,4,8, 1,5,9, 2,6,10, 3,7,11).contiguous().view(D**2*d, D**2*d, D**2*d, D**2*d)
T2 = (As.t()@As).view(D, D, D, D, D, D, D, D).permute(0,4, 1,5, 2,6, 3,7).contiguous().view(D**2, D**2, D**2, D**2)
lnT, error1 = contraction(T1, D**2*d, Dcut, Niter)
lnZ, error2 = contraction(T2, D**2, Dcut, Niter)
lnT = ctmrg(T1, args.chi, args.maxiter, args.epsilon, use_checkpoint=args.use_checkpoint)
lnZ = ctmrg(T2, args.chi, args.maxiter, args.epsilon, use_checkpoint=args.use_checkpoint)
loss = (-lnT + lnZ)
loss.backward(create_graph=True)
grad = A.grad.view(-1)
Agrad = A.grad.data.clone() #take A.grad.data out
#dot it with the given vector
loss = torch.dot(grad, torch.as_tensor(p, device=device))
hvp = torch.autograd.grad(loss, A, retain_graph=True)[0].view(-1)
loss = A.grad.view(-1)@torch.as_tensor(p, device=device)
A.grad.data.zero_()
loss.backward(retain_graph=True)
res = A.grad.detach().cpu().numpy().ravel()
A.grad.data = Agrad #put it back
return res
return hvp.cpu().numpy().ravel()
import scipy.optimize
x0 = A.detach().cpu().numpy().ravel()
......
import torch
from .adlib import SVD
svd = SVD.apply
from .ctmrg import CTMRG
from .vumps import vumps_run, vumps_calc
def projection(T, epsilon=1E-3):
D = T.shape[0] # double layer bond dimension
M = T.view(D, -1)
M = M@M.t()
U, S, _ = svd(M)
#S = S/S.max()
#chi = (S>epsilon).sum().item()
#up to truncation error
chi = (torch.cumsum(S, dim=0)/S.sum() <= 1-epsilon).sum().item()
U = U[:, :chi].view(D, chi)
#print (S/S.max())
#print (torch.cumsum(S, dim=0)/S.sum() )
print ('---->truncated from', D, 'to', chi)
return torch.einsum('abcd,ai,bj,ck,dl->ijkl', (T, U, U, U, U)), U
def ctmrg(T, chi, maxiter, epsilon):
with torch.no_grad():
C, E = CTMRG(T, chi, maxiter, epsilon)
Z1 = torch.einsum('ab,bcd,fd,gha,chij,fjk,lg,mil,mk', (C,E,C,E,T,E,C,E,C))
Z3 = torch.einsum('ab,bc,cd,da', (C,C,C,C))
Z2 = torch.einsum('ab,bcd,de,fa,gcf,ge',(C,E,C,C,E,C))
lnZ = torch.log(Z1.abs()) + torch.log(Z3.abs()) - 2.*torch.log(Z2.abs())
return lnZ
def vumps(T, chi, maxiter, epsilon):
with torch.no_grad():
_, AC, F, C, _ = vumps_run(T, chi, epsilon, maxiter)
lnZ = torch.log(vumps_calc(T, AC, F, C))
return lnZ
from .ctmrg import CTMRG
from .singlelayer import projection, ctmrg, vumps
from .utils import save_checkpoint, load_checkpoint, kronecker_product, symmetrize
from .svd import SVD
from .eigh import EigenSolver
from .qr import QR
'''
PyTorch has its own implementation of backward function for symmetric eigensolver https://github.com/pytorch/pytorch/blob/291746f11047361100102577ce7d1cfa1833be50/tools/autograd/templates/Functions.cpp#L1660
However, it assumes a triangular adjoint.
We reimplement it to return a symmetric adjoint.
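For A = V diag(w) V^T, the backward below returns
dA = V (diag(dw) + F * asym(V^T dV)) V^T, with asym(X) = (X - X^T)/2 and
F_ij = 1/(w_j - w_i) off the diagonal (0 on it), which is symmetric by construction.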
'''
import numpy as np
import torch
class EigenSolver(torch.autograd.Function):
@staticmethod
def forward(self, A):
w, v = torch.symeig(A, eigenvectors=True)
self.save_for_backward(w, v)
return w, v
@staticmethod
def backward(self, dw, dv):
w, v = self.saved_tensors
dtype, device = w.dtype, w.device
N = v.shape[0]
F = w - w[:,None]
F.diagonal().fill_(np.inf)
# safe inverse
msk = (torch.abs(F) < 1e-20)
F[msk] += 1e-20
F = 1./F
vt = v.t()
vdv = vt@dv
return v@(torch.diag(dw) + F*(vdv-vdv.t())/2) @vt
import torch
from torch.utils.checkpoint import detach_variable
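# One power-method step: apply A, fix the overall sign via the first component, and normalize.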
def step(A, x):
y = A@x
y = y[0].sign() * y
return y/y.norm()
class FixedPoint(torch.autograd.Function):
@staticmethod
def forward(ctx, A, x0, tol):
x, x_prev = step(A, x0), x0
while torch.dist(x, x_prev) > tol:
x, x_prev = step(A, x), x
ctx.save_for_backward(A, x)
ctx.tol = tol
return x
@staticmethod
def backward(ctx, grad):
A, x = detach_variable(ctx.saved_tensors)
dA = grad
while True:
with torch.enable_grad():
grad = torch.autograd.grad(step(A, x), x, grad_outputs=grad)[0]
if (torch.norm(grad) > ctx.tol):
dA = dA + grad
else:
break
with torch.enable_grad():
dA = torch.autograd.grad(step(A, x), A, grad_outputs=dA)[0]
return dA, None, None
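# A minimal usage sketch (not part of the original file; sizes and tolerance are illustrative):
# differentiate the dominant eigenvector of a symmetric matrix through the fixed-point iteration.
if __name__ == '__main__':
    torch.manual_seed(0)
    A = torch.randn(5, 5, dtype=torch.float64)
    A = (A + A.t()).requires_grad_()
    x0 = torch.rand(5, dtype=torch.float64)
    x0 = x0 / x0.norm()
    x = FixedPoint.apply(A, x0, 1E-10)
    x.sum().backward()   # gradient of a scalar of the fixed point w.r.t. A, via the implicit backward above
    print(A.grad.norm().item())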
import torch
class QR(torch.autograd.Function):
@staticmethod
def forward(self, A):
Q, R = torch.qr(A)
self.save_for_backward(A, Q, R)
return Q, R
@staticmethod
def backward(self, dq, dr):
A, q, r = self.saved_tensors
if r.shape[0] == r.shape[1]:
return _simple_qr_backward(q, r, dq ,dr)
M, N = r.shape
B = A[:,M:]
dU = dr[:,:M]
dD = dr[:,M:]
U = r[:,:M]
da = _simple_qr_backward(q, U, dq+B@dD.t(), dU)
db = q@dD
return torch.cat([da, db], 1)
def _simple_qr_backward(q, r, dq, dr):
if r.shape[-2] != r.shape[-1]:
raise NotImplementedError("QrGrad not implemented when ncols > nrows "
"or full_matrices is true and ncols != nrows.")
qdq = q.t() @ dq
qdq_ = qdq - qdq.t()
rdr = r @ dr.t()
rdr_ = rdr - rdr.t()
tril = torch.tril(qdq_ + rdr_)
def _TriangularSolve(x, r):
"""Equiv to x @ torch.inverse(r).t() if r is upper-tri."""
res = torch.triangular_solve(x.t(), r, upper=True, transpose=False)[0].t()
return res
grad_a = q @ (dr + _TriangularSolve(tril, r))
grad_b = _TriangularSolve(dq - q @ qdq, r)
return grad_a + grad_b
'''
PyTorch has its own implementation of backward function for SVD https://github.com/pytorch/pytorch/blob/291746f11047361100102577ce7d1cfa1833be50/tools/autograd/templates/Functions.cpp#L1577
We reimplement it with a safe inverse function in light of degenerate singular values
'''
import numpy as np
import torch
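# safe_inverse behaves like 1/x for |x| >> sqrt(epsilon) but stays finite (0 at x = 0) when singular values (nearly) coincide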
def safe_inverse(x, epsilon=1E-12):
return x/(x**2 + epsilon)
class SVD(torch.autograd.Function):
@staticmethod
def forward(self, A):
U, S, V = torch.svd(A)
self.save_for_backward(U, S, V)
return U, S, V
@staticmethod
def backward(self, dU, dS, dV):
U, S, V = self.saved_tensors
Vt = V.t()
Ut = U.t()
M = U.size(0)
N = V.size(0)
NS = len(S)
F = (S - S[:, None])
F = safe_inverse(F)
F.diagonal().fill_(0)
G = (S + S[:, None])
G.diagonal().fill_(np.inf)
G = 1/G
UdU = Ut @ dU
VdV = Vt @ dV
Su = (F+G)*(UdU-UdU.t())/2
Sv = (F-G)*(VdV-VdV.t())/2
dA = U @ (Su + Sv + torch.diag(dS)) @ Vt
if (M>NS):
dA = dA + (torch.eye(M, dtype=dU.dtype, device=dU.device) - U@Ut) @ (dU/S) @ Vt
if (N>NS):
dA = dA + (U/S) @ dV.t() @ (torch.eye(N, dtype=dU.dtype, device=dU.device) - V@Vt)
return dA
import torch
from torch.utils.checkpoint import checkpoint
from .adlib import SVD
svd = SVD.apply
#from .adlib import EigenSolver
#symeig = EigenSolver.apply
def renormalize(*tensors):
# T(up,left,down,right), u=up, l=left, d=down, r=right
# C(d,r), EL(u,r,d), EU(l,d,r)
C, E, T, chi = tensors
dimT, dimE = T.shape[0], E.shape[0]
D_new = min(dimE*dimT, chi)
# step 1: construct the density matrix Rho
Rho = torch.tensordot(C,E,([1],[0])) # C(ef)*EU(fga)=Rho(ega)
Rho = torch.tensordot(Rho,E,([0],[0])) # Rho(ega)*EL(ehc)=Rho(gahc)
Rho = torch.tensordot(Rho,T,([0,2],[0,1])) # Rho(gahc)*T(ghdb)=Rho(acdb)
Rho = Rho.permute(0,3,1,2).contiguous().view(dimE*dimT, dimE*dimT) # Rho(acdb)->Rho(ab;cd)
Rho = Rho+Rho.t()
Rho = Rho/Rho.norm()
# step 2: Get Isometry P
U, S, V = torch.svd(Rho)
truncation_error = S[D_new:].sum()/S.sum()
P = U[:, :D_new] # projection operator
#can also do symeig since Rho is symmetric
#S, U = symeig(Rho)
#sorted, indices = torch.sort(S.abs(), descending=True)
#truncation_error = sorted[D_new:].sum()/sorted.sum()
#S = S[indices][:D_new]
#P = U[:, indices][:, :D_new] # projection operator
# step 3: renormalize C and E
C = (P.t() @ Rho @ P) #C(D_new, D_new)
## EL(u,r,d)
P = P.view(dimE,dimT,D_new)
E = torch.tensordot(E, P, ([0],[0])) # EL(def)P(dga)=E(efga)
E = torch.tensordot(E, T, ([0,2],[1,0])) # E(efga)T(gehb)=E(fahb)
E = torch.tensordot(E, P, ([0,2],[0,1])) # E(fahb)P(fhc)=E(abc)
# step 4: symmetrize C and E
C = 0.5*(C+C.t())
E = 0.5*(E + E.permute(2, 1, 0))
return C/C.norm(), E, S.abs()/S.abs().max(), truncation_error
def CTMRG(T, chi, max_iter, epsilon, use_checkpoint=False):
# T(up, left, down, right)
# C(down, right), E(up,right,down)
C = T.sum((0,1))  # initial corner: trace out the up and left legs of T
E = T.sum(1).permute(0,2,1)  # initial edge: trace out the left leg, reorder to (up, right, down)
truncation_error = 0.0
sold = torch.zeros(chi, dtype=T.dtype, device=T.device)
diff = 1E1
for n in range(max_iter):
tensors = C, E, T, torch.tensor(chi)
if use_checkpoint: # use checkpoint to save memory
C, E, s, error = checkpoint(renormalize, *tensors)
else:
C, E, s, error = renormalize(*tensors)
Enorm = E.norm()
E = E/Enorm
truncation_error += error.item()
if (s.numel() == sold.numel()):
diff = (s-sold).norm().item()
#print( s, sold )
#print( 'n: %d, Enorm: %g, error: %e, diff: %e' % (n, Enorm, error.item(), diff) )
if (diff < epsilon):
break
sold = s
#print ('---->ctmrg converged at iterations %d to %.5e, truncation error: %.5f'%(n, diff, truncation_error/n))
return C, E
if __name__=='__main__':
import time
torch.manual_seed(42)
D = 6
chi = 80
max_iter = 100
device = 'cpu'
# T(u,l,d,r)
T = torch.randn(D, D, D, D, dtype=torch.float64, device=device, requires_grad=True)
T = (T + T.permute(0, 3, 2, 1))/2. # left-right symmetry
T = (T + T.permute(2, 1, 0, 3))/2. # up-down symmetry
T = (T + T.permute(3, 2, 1, 0))/2. # skew-diagonal symmetry
T = (T + T.permute(1, 0, 3, 2))/2. # diagonal symmetry
T = T/T.norm()
epsilon = 1E-12  # convergence tolerance for the CTMRG iteration (illustrative value)
C, E = CTMRG(T, chi, max_iter, epsilon, use_checkpoint=True)
C, E = CTMRG(T, chi, max_iter, epsilon, use_checkpoint=False)
print( 'diffC = ', torch.dist( C, C.t() ) )
print( 'diffE = ', torch.dist( E, E.permute(2,1,0) ) )
import torch
from scipy.sparse.linalg import eigs
from scipy.sparse.linalg import LinearOperator
def dominant_eigensolver(Hopt, v0, tol):
N = v0.shape[0]
A = LinearOperator((N, N), matvec=lambda x: Hopt(torch.as_tensor(x)).detach().numpy())
vals, vecs = eigs(A, k=1, which='LM', v0=v0.detach().numpy(), tol=tol)
w = torch.as_tensor(vals[0].real, dtype=v0.dtype, device=v0.device)
v = torch.as_tensor(vecs[:, 0].real, dtype=v0.dtype, device=v0.device)
return w, v
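# A minimal usage sketch (not part of the original file; sizes and tolerance are illustrative):
# largest-magnitude eigenpair of a random symmetric matrix through the scipy wrapper above.
if __name__ == '__main__':
    torch.manual_seed(0)
    M = torch.randn(20, 20, dtype=torch.float64)
    M = M + M.t()
    w, v = dominant_eigensolver(lambda x: M @ x, torch.rand(20, dtype=torch.float64), 1E-10)
    print(w.item(), torch.dist(M @ v, w * v).item())   # residual should be ~0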
import torch
def fixgauge_qr(A):
Q, R = torch.qr(A)
Rsgn = torch.diag(R).sign()
Q = Q * Rsgn
R = Rsgn[:, None] * R
return Q, R
def test_FixQR():
M, N = 10, 4
# torch.manual_seed(2)
A = torch.randn(M, N, dtype=torch.float64)
Q0, R0 = torch.qr(A)
Q1, R1 = fixgauge_qr(A)
print('R0 = ', R0)
print('R1 = ', R1)
print('qr0 = ', torch.dist(Q0 @ R0, A))
print('qr1 = ', torch.dist(Q1 @ R1, A))
if __name__ == "__main__":
test_FixQR()
import torch
from .fixgaugeqr import fixgauge_qr
from .dominanteigensolver import dominant_eigensolver
def left_orthonormalize(A, C, epsilon, maxiter=100):
'''
C*A = lambda*AL*C
'''
chi, d = A.shape[0], A.shape[1]
A = A.contiguous().view(chi, d * chi)
def step(C):
_, C = fixgauge_qr(C)
C = C / C.norm()
AC = (C @ A).view(chi * d, chi)
AL, Cnew = fixgauge_qr(AC)
lamb = Cnew.norm()
Cnew = Cnew / lamb
diff = torch.dist(Cnew, C)
return AL.view(chi, d, chi), Cnew, lamb, diff
AL, Cnew, lamb, diff = step(C)
for it in range(maxiter):
def Hopt(x):
'''
C = A^{T}_L C A
'''
C = x.view(chi, chi)
AC = (C @ A).view(chi * d, chi)
y = AL.view(chi * d, chi).t() @ AC
return y.view(-1)
_, C = dominant_eigensolver(Hopt, Cnew.contiguous().view(-1), diff / 10)
C = C.view(chi, chi)
AL, Cnew, lamb, diff = step(C)
if (diff < epsilon):
break
else:
C = Cnew
return AL, C, lamb
def right_orthonormalize(A, C, epsilon):
'''
A*C = lambda*C*AR
'''
AL, C, lamb = left_orthonormalize(A.permute(2, 1, 0), C.permute(1, 0), epsilon)
return AL.permute(2, 1, 0).contiguous(), C.permute(1, 0).contiguous(), lamb
if __name__ == '__main__':
torch.manual_seed(42)
chi = 30
d = 2
epsilon = 1E-12
dtype = torch.float64
A = torch.randn(chi, d, chi, dtype=dtype)
AL, C, lamb = left_orthonormalize(A, torch.eye(chi, dtype=dtype), epsilon)
print(
torch.dist((C @ A.view(chi, d * chi)).view(chi, d, chi), lamb * (AL.view(chi * d, chi) @ C).view(chi, d, chi)))
AL = AL.view(chi * d, chi)
print(torch.dist(AL.t() @ AL, torch.eye(chi, dtype=dtype)))
# AR, C, lamb = right_orthonormalize(A, torch.eye(chi, dtype=dtype), epsilon)
# print (torch.dist((A.view(chi*d, chi)@C).view(-1), lamb*(C@AR.view(chi, d*chi)).view(-1)))
# AR = AR.view(chi, d*chi)
# print (AR@AR.t())
import numpy as np
import torch
import sys, os
testdir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(testdir+"/..")
from adlib.eigh import EigenSolver
def test_eigs():
M = 2
torch.manual_seed(42)
A = torch.rand(M, M, dtype=torch.float64)
A = torch.nn.Parameter(A+A.t())
assert(torch.autograd.gradcheck(EigenSolver.apply, A, eps=1e-6, atol=1e-4))
import numpy as np
import torch
import sys, os
testdir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(testdir+"/..")
from adlib.power import FixedPoint
def test_backward():
N = 4
torch.manual_seed(2)
A = torch.rand(N, N, dtype=torch.float64, requires_grad=True)
x0 = torch.rand(N, dtype=torch.float64)
x0 = x0/x0.norm()
tol = 1E-10
input = A, x0, tol
assert(torch.autograd.gradcheck(FixedPoint.apply, input, eps=1E-6, atol=tol))
def test_forward():
torch.manual_seed(42)
N = 100
tol = 1E-8
dtype = torch.float64
A = torch.randn(N, N, dtype=dtype)
A = A+A.t()
w, v = torch.symeig(A, eigenvectors=True)
idx = torch.argmax(w.abs())
v_exact = v[:, idx]
v_exact = v_exact[0].sign() * v_exact
x0 = torch.rand(N, dtype=dtype)
x0 = x0/x0.norm()
x = FixedPoint.apply(A, x0, tol)
assert(torch.allclose(v_exact, x, rtol=tol, atol=tol))
import numpy as np
import torch
import sys, os
testdir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(testdir+"/..")
from adlib.qr import QR
def test_qr():
M, N = 4, 6
torch.manual_seed(2)
A = torch.randn(M, N, dtype=torch.float64)
A.requires_grad=True
assert(torch.autograd.gradcheck(QR.apply, A, eps=1e-4, atol=1e-2))
import numpy as np
import torch
import sys, os
testdir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(testdir+"/..")
from adlib.svd import SVD
def test_svd():
M, N = 20, 16
torch.manual_seed(2)
input = torch.rand(M, N, dtype=torch.float64, requires_grad=True)
assert(torch.autograd.gradcheck(SVD.apply, input, eps=1e-6, atol=1e-4))
import torch
import numpy as np
import collections.abc
""" A module for the function ncon, which does contractions of several tensors.
"""
def ncon(AA, v, order=None, forder=None, check_indices=True):
""" AA = [A1, A2, ..., Ap] list of tensors.
v = (v1, v2, ..., vp) tuple of lists of indices e.g. v1 = [3 4 -1] labels
the three indices of tensor A1, with -1 indicating an uncontracted index
(open leg) and 3 and 4 being the contracted indices.
order, if present, contains a list of all positive indices - if not
[1 2 3 4 ...] by default. This is the order in which they are contracted.
forder, if present, contains the final ordering of the uncontracted indices
- if not, [-1 -2 ..] by default.
There is some leeway in the way the inputs are given. For example,
instead of giving a list of tensors as the first argument one can
give some different iterable of tensors, such as a tuple, or a
single tensor by itself (anything that has the attribute "shape"
will be considered a tensor).
"""
# We want to handle the tensors as a list, regardless of what kind
# of iterable we are given. In addition, if only a single element is
# given, we make a list out of it. Inputs are assumed to be non-empty.
if hasattr(AA, "shape"):
AA = [AA]
else:
AA = list(AA)
v = list(v)
if not isinstance(v[0], collections.abc.Iterable):
# v is not a list of lists, so make it such.
v = [v]
else:
v = list(map(list, v))
if order is None:
order = create_order(v)
if forder is None:
forder = create_forder(v)
if check_indices:
# Raise a RuntimeError if the indices are wrong.
do_check_indices(AA, v, order, forder)
# If the graph is disconnected, connect it with trivial indices that
# will be contracted at the very end.
connect_graph(AA, v, order)
while len(order) > 0:
tcon = get_tcon(v, order[0]) # tcon = tensors to be contracted
# Find the indices icon that are to be contracted.
if len(tcon) == 1:
tracing = True
icon = [order[0]]
else:
tracing = False
icon = get_icon(v, tcon)
# Position in tcon[0] and tcon[1] of indices to be contracted.
# In the case of trace, pos2 = []
pos1, pos2 = get_pos(v, tcon, icon)
if tracing:
# Trace on a tensor
new_A = trace(AA[tcon[0]], axis1=pos1[0], axis2=pos1[1])
else:
# Contraction of 2 tensors
new_A = con(AA[tcon[0]], AA[tcon[1]], (pos1, pos2))
AA.append(new_A)
v.append(find_newv(v, tcon, icon)) # Add the v for the new tensor
for i in sorted(tcon, reverse=True):
# Delete the contracted tensors and indices from the lists.
# tcon is reverse sorted so that tensors are removed starting from
# the end of AA, otherwise the order would get messed up.
del AA[i]
del v[i]
order = renew_order(order, icon) # Update order
vlast = v[0]
A = AA[0]
A = permute_final(A, vlast, forder)
return A
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
def create_order(v):
""" Identify all unique, positive indices and return them sorted. """
flat_v = sum(v, [])
x = [i for i in flat_v if i > 0]
# Converting to a set and back removes duplicates
x = list(set(x))
return sorted(x)
def create_forder(v):
""" Identify all unique, negative indices and return them reverse sorted
(-1 first).
"""
flat_v = sum(v, [])
x = [i for i in flat_v if i < 0]
# Converting to a set and back removes duplicates
x = list(set(x))
return sorted(x, reverse=True)
def connect_graph(AA, v, order):
""" Connect the graph of tensors to be contracted by trivial
indices, if necessary. Add these trivial indices to the end of the
contraction order.
AA, v and order are modified in place.
"""
# Build ccomponents, a list of the connected components of the graph,
# where each component is represented by a set of indices.
unvisited = set(range(len(AA)))
visited = set()
ccomponents = []
while unvisited:
component = set()
next_visit = unvisited.pop()
to_visit = {next_visit}
while to_visit:
i = to_visit.pop()
unvisited.discard(i)
component.add(i)
visited.add(i)
# Get the indices of tensors neighbouring AA[i].
i_inds = set(v[i])
neighs = (j for j, j_inds in enumerate(v) if i_inds.intersection(j_inds))
for neigh in neighs:
if neigh not in visited:
to_visit.add(neigh)
ccomponents.append(component)
# If there is more than one connected component, pick one of them, then
# take an arbitrary tensor (called c) out of it, and connect that
# tensor with an arbitrary tensor (called d) from all the other
# components using a trivial index.
c = ccomponents.pop().pop()
while ccomponents:
d = ccomponents.pop().pop()
A_c = AA[c]
A_d = AA[d]
c_axis = len(v[c])
d_axis = len(v[d])
try:
AA[c] = A_c.expand_dims(c_axis, direction=1)
except AttributeError:
AA[c] = np.expand_dims(A_c, c_axis)
try:
AA[d] = A_d.expand_dims(d_axis, direction=-1)
except AttributeError:
AA[d] = np.expand_dims(A_d, d_axis)
try:
dim_num = max(order) + 1
except ValueError:
dim_num = 1
v[c].append(dim_num)
v[d].append(dim_num)
order.append(dim_num)
return None
def get_tcon(v, index):
""" Gets the list indices in AA of the tensors that have index as their
leg.
"""
tcon = []
for i, inds in enumerate(v):
if index in inds:
tcon.append(i)
l = len(tcon)
# If check_indices is called and it does its work properly then these
# checks should in fact be unnecessary.
if l > 2:
raise ValueError('In ncon.get_tcon, more than two tensors share a '
'contraction index.')
elif l < 1:
raise ValueError('In ncon.get_tcon, no tensor has this '
'contraction index.')
elif l == 1:
# The contraction is a trace.
how_many = v[tcon[0]].count(index)
if how_many != 2:
# Only one tensor has this index but it is not a trace because it
# does not occur twice for that tensor.
raise ValueError('In ncon.get_tcon, a trace index is listed '
'!= 2 times for the same tensor.')
return tcon
def get_icon(v, tcon):
""" Returns a list of indices that are to be contracted when contractions
between the two tensors numbered in tcon are contracted. """
inds1 = v[tcon[0]]
inds2 = v[tcon[1]]
icon = set(inds1).intersection(inds2)
icon = list(icon)
return icon
def get_pos(v, tcon, icon):
""" Get the positions of the indices icon in the list of legs the tensors
tcon to be contracted.
"""
pos1 = [[i for i, x in enumerate(v[tcon[0]]) if x == e] for e in icon]
pos1 = sum(pos1, [])
if len(tcon) < 2:
pos2 = []
else:
pos2 = [[i for i, x in enumerate(v[tcon[1]]) if x == e] for e in icon]
pos2 = sum(pos2, [])
return pos1, pos2
def find_newv(v, tcon, icon):
""" Find the list of indices for the new tensor after contraction of
indices icon of the tensors tcon.
"""
if len(tcon) == 2:
newv = v[tcon[0]] + v[tcon[1]]
else:
newv = v[tcon[0]]
newv = [i for i in newv if i not in icon]
return newv
def renew_order(order, icon):
""" Returns the new order with the contracted indices removed from it. """
return [i for i in order if i not in icon]
def permute_final(A, v, forder):
""" Returns the final tensor A with its legs permuted to the order given
in forder.
"""
perm = [v.index(i) for i in forder]
return A.permute(tuple(perm))
def do_check_indices(AA, v, order, forder):
""" Check that
1) the number of tensors in AA matches the number of index lists in v.
2) every tensor is given the right number of indices.
3) every contracted index is featured exactly twice and every free index
exactly once.
4) the dimensions of the two ends of each contracted index match.
"""
# 1)
if len(AA) != len(v):
raise ValueError(('In ncon.do_check_indices, the number of tensors %i'
' does not match the number of index lists %i')
% (len(AA), len(v)))
# 2)
# Create a list of lists with the shapes of each A in AA.
shapes = list(map(lambda A: list(A.shape), AA))
for i, inds in enumerate(v):
if len(inds) != len(shapes[i]):
raise ValueError(('In ncon.do_check_indices, len(v[%i])=%i '
'does not match the numbers of indices of '
'AA[%i] = %i') % (i, len(inds), i,
len(shapes[i])))
# 3) and 4)
# v_pairs = [[(0,0), (0,1), (0,2), ...], [(1,0), (1,1), (1,2), ...], ...]
v_pairs = [[(i, j) for j in range(len(s))] for i, s in enumerate(v)]
v_pairs = sum(v_pairs, [])
v_sum = sum(v, [])
# For t, o in zip(v_pairs, v_sum) t is the tuple of the number of
# the tensor and the index and o is the contraction order of that
# index. We group these tuples by the contraction order.
order_groups = [[t for t, o in zip(v_pairs, v_sum) if o == e]
for e in order]
forder_groups = [[1 for fo in v_sum if fo == e] for e in forder]
for i, o in enumerate(order_groups):
if len(o) != 2:
raise ValueError(('In ncon.do_check_indices, the contracted index '
'%i is not featured exactly twice in v.') % order[i])
else:
A0, ind0 = o[0]
A1, ind1 = o[1]
try:
compatible = AA[A0].compatible_indices(AA[A1], ind0, ind1)
except AttributeError:
compatible = AA[A0].shape[ind0] == AA[A1].shape[ind1]
if not compatible:
raise ValueError('In ncon.do_check_indices, for the '
'contraction index %i, the leg %i of tensor '
'number %i and the leg %i of tensor number '
'%i are not compatible.'
% (order[i], ind0, A0, ind1, A1))
for i, fo in enumerate(forder_groups):
if len(fo) != 1:
raise ValueError(('In ncon.do_check_indices, the free index '
'%i is not featured exactly once in v.') % forder[i])
# All is well if we made it here.
return True
####################################################################
# The following are simple wrappers around numpy/Tensor functions, #
# but may be replaced with fancier stuff later. #
####################################################################
def con(A, B, inds):
if torch.is_tensor(A) and torch.is_tensor(B):
return torch.tensordot(A, B, inds)
else:
return A.dot(B, inds)
def trace(A, axis1=0, axis2=1):
if torch.is_tensor(A):
return torch.diagonal(A, dim1=axis1, dim2=axis2).sum(-1) # Tensor.trace takes no axis arguments
return A.trace(axis1=axis1, axis2=axis2)
if __name__ == '__main__':
A = torch.randn(2, 3, 4, 5, dtype=torch.float64)
B = torch.randn(4, 5, 2, dtype=torch.float64)
C_ncon = ncon([A, B], ([-1, -2, 1, 2], [1, 2, -3]), [1, 2], [-3, -1, -2])
C_einsum = torch.einsum('abcd,cde->eab', (A, B))
print((C_ncon - C_einsum).abs().sum())
import torch
from .adlib import SVD
svd = SVD.apply
def renormalize(*args):
T, chi, epsilon = args
D = T.shape[0]
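# Levin-Nave TRG step: split T along its two diagonals by truncated SVDs (Ma groups (u,l)|(d,r), Mb groups (l,d)|(u,r)),
# then contract the four rank-3 factors into the coarse-grained tensor on the rotated lattice.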
Ma = T.view(D**2, D**2)
Mb = T.permute(1, 2, 0, 3).contiguous().view(D**2, D**2)
Ua, Sa, Va = svd(Ma)
Ub, Sb, Vb = svd(Mb)
D_new = min(min(D**2, chi), min((Sa>epsilon).sum().item(), (Sb>epsilon).sum().item()))
S1 = (Ua[:, :D_new]* torch.sqrt(Sa[:D_new])).view(D, D, D_new)
S3 = (Va[:, :D_new]* torch.sqrt(Sa[:D_new])).view(D, D, D_new)
S2 = (Ub[:, :D_new]* torch.sqrt(Sb[:D_new])).view(D, D, D_new)
S4 = (Vb[:, :D_new]* torch.sqrt(Sb[:D_new])).view(D, D, D_new)
return torch.einsum('xwu,yxl,yzd,wzr->uldr', (S2, S3, S4, S1))
def TRG(T, chi, no_iter, epsilon=1E-15):
lnZ = 0.0
for n in range(no_iter):
maxval = T.abs().max()
T = T/maxval
lnZ += 2**(no_iter-n)*torch.log(maxval)
T = renormalize(T, chi, epsilon)
trace = 0.0
for x in range(T.shape[0]):
for y in range(T.shape[1]):
trace += T[x, y, x, y]
lnZ += torch.log(trace)
return lnZ/2**no_iter