From a813c8cb77f3c35831ff41bbce35fbfec80523fe Mon Sep 17 00:00:00 2001 From: Lei Wang Date: Mon, 17 Feb 2020 21:06:09 +0800 Subject: [PATCH] use tensornets; enable checkpoint --- contraction.py | 26 +++ dimer_covering.py | 82 +++----- singlelayer.py | 48 +++++ tensornets/__init__.py | 3 + tensornets/adlib/__init__.py | 3 + tensornets/adlib/eigh.py | 35 ++++ tensornets/adlib/power.py | 33 +++ tensornets/adlib/qr.py | 43 ++++ tensornets/adlib/svd.py | 48 +++++ tensornets/ctmrg.py | 108 ++++++++++ tensornets/dominanteigensolver.py | 13 ++ tensornets/fixgaugeqr.py | 28 +++ tensornets/orthonormalize.py | 71 +++++++ tensornets/singlelayer.py | 48 +++++ tensornets/tests/.test_eigh.py | 14 ++ tensornets/tests/test_power.py | 38 ++++ tensornets/tests/test_qr.py | 14 ++ tensornets/tests/test_svd.py | 13 ++ tensornets/torchncon.py | 330 ++++++++++++++++++++++++++++++ tensornets/trg.py | 41 ++++ tensornets/utils.py | 63 ++++++ tensornets/vumps.py | 104 ++++++++++ 22 files changed, 1155 insertions(+), 51 deletions(-) create mode 100644 contraction.py create mode 100644 singlelayer.py create mode 100644 tensornets/__init__.py create mode 100644 tensornets/adlib/__init__.py create mode 100644 tensornets/adlib/eigh.py create mode 100644 tensornets/adlib/power.py create mode 100644 tensornets/adlib/qr.py create mode 100644 tensornets/adlib/svd.py create mode 100644 tensornets/ctmrg.py create mode 100644 tensornets/dominanteigensolver.py create mode 100644 tensornets/fixgaugeqr.py create mode 100644 tensornets/orthonormalize.py create mode 100644 tensornets/singlelayer.py create mode 100644 tensornets/tests/.test_eigh.py create mode 100644 tensornets/tests/test_power.py create mode 100644 tensornets/tests/test_qr.py create mode 100644 tensornets/tests/test_svd.py create mode 100644 tensornets/torchncon.py create mode 100644 tensornets/trg.py create mode 100644 tensornets/utils.py create mode 100644 tensornets/vumps.py diff --git a/contraction.py b/contraction.py new file mode 100644 index 0000000..87d8d4d --- /dev/null +++ b/contraction.py @@ -0,0 +1,26 @@ +import torch +from tensornets import CTMRG + +def ctmrg(T, chi, maxiter, epsilon, use_checkpoint=False): + + C, E = CTMRG(T, chi, maxiter, epsilon, use_checkpoint=use_checkpoint) + + Z1 = torch.einsum('ab,bcd,fd,gha,chij,fjk,lg,mil,mk', (C,E,C,E,T,E,C,E,C)) + Z3 = torch.einsum('ab,bc,cd,da', (C,C,C,C)) + Z2 = torch.einsum('ab,bcd,de,fa,gcf,ge',(C,E,C,C,E,C)) + lnZ = torch.log(Z1.abs()) + torch.log(Z3.abs()) - 2.*torch.log(Z2.abs()) + + return lnZ + +def symmetrize(A): + ''' + A(phy, up, left, down, right) + left-right, up-down, diagonal symmetrize + ''' + Asymm = (A + A.permute(0, 1, 4, 3, 2))/2. # left-right symmetry + Asymm = (Asymm + Asymm.permute(0, 3, 2, 1, 4))/2. # up-down symmetry + Asymm = (Asymm + Asymm.permute(0, 4, 3, 2, 1))/2. # skew-diagonal symmetry + Asymm = (Asymm + Asymm.permute(0, 2, 1, 4, 3))/2. # diagonal symmetry + + return Asymm/Asymm.norm() + diff --git a/dimer_covering.py b/dimer_covering.py index 7217566..fbb462a 100644 --- a/dimer_covering.py +++ b/dimer_covering.py @@ -8,50 +8,31 @@ import torch torch.set_num_threads(1) torch.manual_seed(42) -#from trg import levin_nave_trg as contraction -from ctmrg import ctmrg as contraction -#from vmps import vmps as contraction - -def symmetrize(A): - As = A.view(d, D, D, D, D) - As = (As + As.permute(0, 4, 2, 3, 1))/2. - As = (As + As.permute(0, 1, 3, 2, 4))/2. - As = (As + As.permute(0, 3, 4, 1, 2))/2. - As = (As + As.permute(0, 2, 1, 4, 3))/2. 
- As = As.view(d, D**4) - return As/As.norm() +from contraction import symmetrize, ctmrg if __name__=='__main__': import argparse parser = argparse.ArgumentParser(description='') parser.add_argument("-D", type=int, default=2, help="D") - parser.add_argument("-Dcut", type=int, default=30, help="Dcut") - parser.add_argument("-Niter", type=int, default=50, help="Niter") - parser.add_argument("-Nepochs", type=int, default=10, help="Nepochs") + parser.add_argument("-chi", type=int, default=30, help="chi") + parser.add_argument("-maxiter", type=int, default=50, help="maxiter") + parser.add_argument("-epsilon", type=float, default=1E-10, help="maxiter") + parser.add_argument("-nepochs", type=int, default=10, help="nepochs") parser.add_argument("-float32", action='store_true', help="use float32") - parser.add_argument("-lanczos_steps", type=int, default=0, help="lanczos steps") + parser.add_argument("-use_checkpoint", action='store_true', help="use checkpoint") parser.add_argument("-cuda", type=int, default=-1, help="use GPU") args = parser.parse_args() device = torch.device("cpu" if args.cuda<0 else "cuda:"+str(args.cuda)) dtype = torch.float32 if args.float32 else torch.float64 - if args.lanczos_steps>0: print ('lanczos steps', args.lanczos_steps) - d = 2 # fixed - D = args.D - Dcut = args.Dcut - Niter = args.Niter B = 0.01* torch.randn(d, D, D, D, D, dtype=dtype, device=device) #symmetrize initial boundary PEPS - A = torch.nn.Parameter(symmetrize(B).view(d, D**4)) + A = torch.nn.Parameter(symmetrize(B)) - #boundary MPS - #A1 = torch.nn.Parameter(0.01*torch.randn(Dcut, D**2*d, Dcut, dtype=dtype, device=device)) - #A2 = torch.nn.Parameter(0.01*torch.randn(Dcut, D**2, Dcut, dtype=dtype, device=device)) - #dimer covering T = torch.zeros(d, d, d, d, d, d, dtype=dtype, device=device) T[0, 0, 0, 0, 0, 1] = 1.0 @@ -66,23 +47,20 @@ if __name__=='__main__': def closure(): optimizer.zero_grad() - As = symmetrize(A) + As = symmetrize(A).view(d, D**4) T1 = torch.einsum('xa,xby,yc' , (As,T,As)).view(D,D,D,D, d,d,d,d, D,D,D,D).permute(0,4,8, 1,5,9, 2,6,10, 3,7,11).contiguous().view(D**2*d, D**2*d, D**2*d, D**2*d) #double layer T2 = (As.t()@As).view(D, D, D, D, D, D, D, D).permute(0,4, 1,5, 2,6, 3,7).contiguous().view(D**2, D**2, D**2, D**2) - #lnT = contraction(T1, D**2*d, Dcut, Niter, A1, lanczos_steps=args.lanczos_steps) - #lnZ = contraction(T2, D**2, Dcut, Niter, A2, lanczos_steps=args.lanczos_steps) - lnT, error1 = contraction(T1, D**2*d, Dcut, Niter) - lnZ, error2 = contraction(T2, D**2, Dcut, Niter) + + lnT = ctmrg(T1, args.chi, args.maxiter, args.epsilon, use_checkpoint=args.use_checkpoint) + lnZ = ctmrg(T2, args.chi, args.maxiter, args.epsilon, use_checkpoint=args.use_checkpoint) loss = (-lnT + lnZ) - #print (' total loss', loss.item()) - #print (' loss, error', loss.item(), error1.item(), error2.item()) loss.backward() return loss - for epoch in range(args.Nepochs): + for epoch in range(args.nepochs): loss = optimizer.step(closure) print ('epoch, loss, gradnorm:', epoch, loss.item(), A.grad.norm().item()) @@ -90,13 +68,14 @@ if __name__=='__main__': ####################################################################3 def fun(x, info): - A = torch.as_tensor(x, device=device).requires_grad_() - As = symmetrize(A) + A = torch.as_tensor(x, device=device).view(d, D, D, D, D).requires_grad_() + As = symmetrize(A).view(d, D**4) T1 = torch.einsum('xa,xby,yc' , (As,T,As)).view(D,D,D,D, d,d,d,d, D,D,D,D).permute(0,4,8, 1,5,9, 2,6,10, 3,7,11).contiguous().view(D**2*d, D**2*d, D**2*d, D**2*d) T2 = 
(As.t()@As).view(D, D, D, D, D, D, D, D).permute(0,4, 1,5, 2,6, 3,7).contiguous().view(D**2, D**2, D**2, D**2) - lnT, error1 = contraction(T1, D**2*d, Dcut, Niter) - lnZ, error2 = contraction(T2, D**2, Dcut, Niter) + + lnT = ctmrg(T1, args.chi, args.maxiter, args.epsilon, use_checkpoint=args.use_checkpoint) + lnZ = ctmrg(T2, args.chi, args.maxiter, args.epsilon, use_checkpoint=args.use_checkpoint) loss = (-lnT + lnZ) loss.backward(create_graph=True) @@ -112,30 +91,31 @@ if __name__=='__main__': def hessp(x, p, info): A = info['A'] - - if ((x == A.detach().cpu().numpy().ravel()).all()): - grad = A.grad.view(-1) - else: + if not ((x == A.detach().cpu().numpy().ravel()).all()): print ('recalculate forward and grad') - A = torch.as_tensor(A, device=device).requires_grad_() + A = torch.as_tensor(x, device=device).view(d, D, D, D, D).requires_grad_() - As = symmetrize(A) + As = symmetrize(A).view(d, D**4) T1 = torch.einsum('xa,xby,yc' , (As,T,As)).view(D,D,D,D, d,d,d,d, D,D,D,D).permute(0,4,8, 1,5,9, 2,6,10, 3,7,11).contiguous().view(D**2*d, D**2*d, D**2*d, D**2*d) T2 = (As.t()@As).view(D, D, D, D, D, D, D, D).permute(0,4, 1,5, 2,6, 3,7).contiguous().view(D**2, D**2, D**2, D**2) - lnT, error1 = contraction(T1, D**2*d, Dcut, Niter) - lnZ, error2 = contraction(T2, D**2, Dcut, Niter) + lnT = ctmrg(T1, args.chi, args.maxiter, args.epsilon, use_checkpoint=args.use_checkpoint) + lnZ = ctmrg(T2, args.chi, args.maxiter, args.epsilon, use_checkpoint=args.use_checkpoint) loss = (-lnT + lnZ) loss.backward(create_graph=True) - - grad = A.grad.view(-1) + Agrad = A.grad.data.clone() #take A.grad.data out + #dot it with the given vector - loss = torch.dot(grad, torch.as_tensor(p, device=device)) - hvp = torch.autograd.grad(loss, A, retain_graph=True)[0].view(-1) + loss = A.grad.view(-1)@torch.as_tensor(p, device=device) + A.grad.data.zero_() + loss.backward(retain_graph=True) + + res = A.grad.detach().cpu().numpy().ravel() + A.grad.data = Agrad #put it back + return res - return hvp.cpu().numpy().ravel() import scipy.optimize x0 = A.detach().cpu().numpy().ravel() diff --git a/singlelayer.py b/singlelayer.py new file mode 100644 index 0000000..a2eb556 --- /dev/null +++ b/singlelayer.py @@ -0,0 +1,48 @@ +import torch + +from .adlib import SVD +svd = SVD.apply +from .ctmrg import CTMRG +from .vumps import vumps_run, vumps_calc + +def projection(T, epsilon=1E-3): + + D = T.shape[0] # double layer bond dimension + + M = T.view(D, -1) + M = M@M.t() + + U, S, _ = svd(M) + + #S = S/S.max() + #chi = (S>epsilon).sum().item() + + #up to truncation error + chi = (torch.cumsum(S, dim=0)/S.sum() <= 1-epsilon).sum().item() + + U = U[:, :chi].view(D, chi) + #print (S/S.max()) + #print (torch.cumsum(S, dim=0)/S.sum() ) + print ('---->truncated from', D, 'to', chi) + + return torch.einsum('abcd,ai,bj,ck,dl->ijkl', (T, U, U, U, U)), U + +def ctmrg(T, chi, maxiter, epsilon): + + with torch.no_grad(): + C, E = CTMRG(T, chi, maxiter, epsilon) + + Z1 = torch.einsum('ab,bcd,fd,gha,chij,fjk,lg,mil,mk', (C,E,C,E,T,E,C,E,C)) + Z3 = torch.einsum('ab,bc,cd,da', (C,C,C,C)) + Z2 = torch.einsum('ab,bcd,de,fa,gcf,ge',(C,E,C,C,E,C)) + lnZ = torch.log(Z1.abs()) + torch.log(Z3.abs()) - 2.*torch.log(Z2.abs()) + + return lnZ + +def vumps(T, chi, maxiter, epsilon): + with torch.no_grad(): + _, AC, F, C, _ = vumps_run(T, chi, epsilon, maxiter) + + lnZ = torch.log(vumps_calc(T, AC, F, C)) + return lnZ + diff --git a/tensornets/__init__.py b/tensornets/__init__.py new file mode 100644 index 0000000..a9e68f9 --- /dev/null +++ b/tensornets/__init__.py 
@@ -0,0 +1,3 @@ +from .ctmrg import CTMRG +from .singlelayer import projection, ctmrg, vumps +from .utils import save_checkpoint, load_checkpoint, kronecker_product, symmetrize diff --git a/tensornets/adlib/__init__.py b/tensornets/adlib/__init__.py new file mode 100644 index 0000000..985198f --- /dev/null +++ b/tensornets/adlib/__init__.py @@ -0,0 +1,3 @@ +from .svd import SVD +from .eigh import EigenSolver +from .qr import QR diff --git a/tensornets/adlib/eigh.py b/tensornets/adlib/eigh.py new file mode 100644 index 0000000..fcad0d8 --- /dev/null +++ b/tensornets/adlib/eigh.py @@ -0,0 +1,35 @@ +''' +PyTorch has its own implementation of backward function for symmetric eigensolver https://github.com/pytorch/pytorch/blob/291746f11047361100102577ce7d1cfa1833be50/tools/autograd/templates/Functions.cpp#L1660 +However, it assumes a triangular adjoint. +We reimplement it to return a symmetric adjoint +''' + +import numpy as np +import torch + +class EigenSolver(torch.autograd.Function): + @staticmethod + def forward(self, A): + w, v = torch.symeig(A, eigenvectors=True) + + self.save_for_backward(w, v) + return w, v + + @staticmethod + def backward(self, dw, dv): + w, v = self.saved_tensors + dtype, device = w.dtype, w.device + N = v.shape[0] + + F = w - w[:,None] + F.diagonal().fill_(np.inf) + # safe inverse + msk = (torch.abs(F) < 1e-20) + F[msk] += 1e-20 + F = 1./F + + vt = v.t() + vdv = vt@dv + + return v@(torch.diag(dw) + F*(vdv-vdv.t())/2) @vt + diff --git a/tensornets/adlib/power.py b/tensornets/adlib/power.py new file mode 100644 index 0000000..155f089 --- /dev/null +++ b/tensornets/adlib/power.py @@ -0,0 +1,33 @@ +import torch +from torch.utils.checkpoint import detach_variable + +def step(A, x): + y = A@x + y = y[0].sign() * y + return y/y.norm() + +class FixedPoint(torch.autograd.Function): + @staticmethod + def forward(ctx, A, x0, tol): + x, x_prev = step(A, x0), x0 + while torch.dist(x, x_prev) > tol: + x, x_prev = step(A, x), x + ctx.save_for_backward(A, x) + ctx.tol = tol + return x + + @staticmethod + def backward(ctx, grad): + A, x = detach_variable(ctx.saved_tensors) + dA = grad + while True: + with torch.enable_grad(): + grad = torch.autograd.grad(step(A, x), x, grad_outputs=grad)[0] + if (torch.norm(grad) > ctx.tol): + dA = dA + grad + else: + break + with torch.enable_grad(): + dA = torch.autograd.grad(step(A, x), A, grad_outputs=dA)[0] + return dA, None, None + diff --git a/tensornets/adlib/qr.py b/tensornets/adlib/qr.py new file mode 100644 index 0000000..a2f8cdb --- /dev/null +++ b/tensornets/adlib/qr.py @@ -0,0 +1,43 @@ +import torch + +class QR(torch.autograd.Function): + @staticmethod + def forward(self, A): + Q, R = torch.qr(A) + self.save_for_backward(A, Q, R) + return Q, R + + @staticmethod + def backward(self, dq, dr): + A, q, r = self.saved_tensors + if r.shape[0] == r.shape[1]: + return _simple_qr_backward(q, r, dq ,dr) + M, N = r.shape + B = A[:,M:] + dU = dr[:,:M] + dD = dr[:,M:] + U = r[:,:M] + da = _simple_qr_backward(q, U, dq+B@dD.t(), dU) + db = q@dD + return torch.cat([da, db], 1) + +def _simple_qr_backward(q, r, dq, dr): + if r.shape[-2] != r.shape[-1]: + raise NotImplementedError("QrGrad not implemented when ncols > nrows " + "or full_matrices is true and ncols != nrows.") + + qdq = q.t() @ dq + qdq_ = qdq - qdq.t() + rdr = r @ dr.t() + rdr_ = rdr - rdr.t() + tril = torch.tril(qdq_ + rdr_) + + def _TriangularSolve(x, r): + """Equiv to x @ torch.inverse(r).t() if r is upper-tri.""" + res = torch.triangular_solve(x.t(), r, upper=True, 
transpose=False)[0].t() + return res + + grad_a = q @ (dr + _TriangularSolve(tril, r)) + grad_b = _TriangularSolve(dq - q @ qdq, r) + return grad_a + grad_b + diff --git a/tensornets/adlib/svd.py b/tensornets/adlib/svd.py new file mode 100644 index 0000000..1a8d903 --- /dev/null +++ b/tensornets/adlib/svd.py @@ -0,0 +1,48 @@ +''' +PyTorch has its own implementation of backward function for SVD https://github.com/pytorch/pytorch/blob/291746f11047361100102577ce7d1cfa1833be50/tools/autograd/templates/Functions.cpp#L1577 +We reimplement it with a safe inverse function in light of degenerated singular values +''' + +import numpy as np +import torch + +def safe_inverse(x, epsilon=1E-12): + return x/(x**2 + epsilon) + +class SVD(torch.autograd.Function): + @staticmethod + def forward(self, A): + U, S, V = torch.svd(A) + self.save_for_backward(U, S, V) + return U, S, V + + @staticmethod + def backward(self, dU, dS, dV): + U, S, V = self.saved_tensors + Vt = V.t() + Ut = U.t() + M = U.size(0) + N = V.size(0) + NS = len(S) + + F = (S - S[:, None]) + F = safe_inverse(F) + F.diagonal().fill_(0) + + G = (S + S[:, None]) + G.diagonal().fill_(np.inf) + G = 1/G + + UdU = Ut @ dU + VdV = Vt @ dV + + Su = (F+G)*(UdU-UdU.t())/2 + Sv = (F-G)*(VdV-VdV.t())/2 + + dA = U @ (Su + Sv + torch.diag(dS)) @ Vt + if (M>NS): + dA = dA + (torch.eye(M, dtype=dU.dtype, device=dU.device) - U@Ut) @ (dU/S) @ Vt + if (N>NS): + dA = dA + (U/S) @ dV.t() @ (torch.eye(N, dtype=dU.dtype, device=dU.device) - V@Vt) + return dA + diff --git a/tensornets/ctmrg.py b/tensornets/ctmrg.py new file mode 100644 index 0000000..13b460b --- /dev/null +++ b/tensornets/ctmrg.py @@ -0,0 +1,108 @@ +import torch +from torch.utils.checkpoint import checkpoint + +from .adlib import SVD +svd = SVD.apply +#from .adlib import EigenSolver +#symeig = EigenSolver.apply + +def renormalize(*tensors): + # T(up,left,down,right), u=up, l=left, d=down, r=right + # C(d,r), EL(u,r,d), EU(l,d,r) + + C, E, T, chi = tensors + + dimT, dimE = T.shape[0], E.shape[0] + D_new = min(dimE*dimT, chi) + + # step 1: contruct the density matrix Rho + Rho = torch.tensordot(C,E,([1],[0])) # C(ef)*EU(fga)=Rho(ega) + Rho = torch.tensordot(Rho,E,([0],[0])) # Rho(ega)*EL(ehc)=Rho(gahc) + Rho = torch.tensordot(Rho,T,([0,2],[0,1])) # Rho(gahc)*T(ghdb)=Rho(acdb) + Rho = Rho.permute(0,3,1,2).contiguous().view(dimE*dimT, dimE*dimT) # Rho(acdb)->Rho(ab;cd) + + Rho = Rho+Rho.t() + Rho = Rho/Rho.norm() + + # step 2: Get Isometry P + U, S, V = torch.svd(Rho) + truncation_error = S[D_new:].sum()/S.sum() + P = U[:, :D_new] # projection operator + + #can also do symeig since Rho is symmetric + #S, U = symeig(Rho) + #sorted, indices = torch.sort(S.abs(), descending=True) + #truncation_error = sorted[D_new:].sum()/sorted.sum() + #S = S[indices][:D_new] + #P = U[:, indices][:, :D_new] # projection operator + + # step 3: renormalize C and E + C = (P.t() @ Rho @ P) #C(D_new, D_new) + + ## EL(u,r,d) + P = P.view(dimE,dimT,D_new) + E = torch.tensordot(E, P, ([0],[0])) # EL(def)P(dga)=E(efga) + E = torch.tensordot(E, T, ([0,2],[1,0])) # E(efga)T(gehb)=E(fahb) + E = torch.tensordot(E, P, ([0,2],[0,1])) # E(fahb)P(fhc)=E(abc) + + # step 4: symmetrize C and E + C = 0.5*(C+C.t()) + E = 0.5*(E + E.permute(2, 1, 0)) + + return C/C.norm(), E, S.abs()/S.abs().max(), truncation_error + + +def CTMRG(T, chi, max_iter, epsilon, use_checkpoint=False): + # T(up, left, down, right) + + + # C(down, right), E(up,right,down) + C = T.sum((0,1)) # + E = T.sum(1).permute(0,2,1) + + truncation_error = 0.0 + sold = 
torch.zeros(chi, dtype=T.dtype, device=T.device)
+    diff = 1E1
+    for n in range(max_iter):
+        tensors = C, E, T, torch.tensor(chi)
+        if use_checkpoint: # use checkpoint to save memory
+            C, E, s, error = checkpoint(renormalize, *tensors)
+        else:
+            C, E, s, error = renormalize(*tensors)
+
+        Enorm = E.norm()
+        E = E/Enorm
+        truncation_error += error.item()
+        if (s.numel() == sold.numel()):
+            diff = (s-sold).norm().item()
+            #print( s, sold )
+            #print( 'n: %d, Enorm: %g, error: %e, diff: %e' % (n, Enorm, error.item(), diff) )
+        if (diff < epsilon):
+            break
+        sold = s
+    #print ('---->ctmrg converged at iterations %d to %.5e, truncation error: %.5f'%(n, diff, truncation_error/n))
+
+    return C, E
+
+if __name__=='__main__':
+    import time
+    torch.manual_seed(42)
+    D = 6
+    chi = 80
+    max_iter = 100
+    device = 'cpu'
+
+    # T(u,l,d,r)
+    T = torch.randn(D, D, D, D, dtype=torch.float64, device=device, requires_grad=True)
+
+    T = (T + T.permute(0, 3, 2, 1))/2. # left-right symmetry
+    T = (T + T.permute(2, 1, 0, 3))/2. # up-down symmetry
+    T = (T + T.permute(3, 2, 1, 0))/2. # skew-diagonal symmetry
+    T = (T + T.permute(1, 0, 3, 2))/2. # diagonal symmetry
+    T = T/T.norm()
+
+    C, E = CTMRG(T, chi, max_iter, 1E-10, use_checkpoint=True)
+    C, E = CTMRG(T, chi, max_iter, 1E-10, use_checkpoint=False)
+    print( 'diffC = ', torch.dist( C, C.t() ) )
+    print( 'diffE = ', torch.dist( E, E.permute(2,1,0) ) )
+
diff --git a/tensornets/dominanteigensolver.py b/tensornets/dominanteigensolver.py
new file mode 100644
index 0000000..9639933
--- /dev/null
+++ b/tensornets/dominanteigensolver.py
@@ -0,0 +1,13 @@
+import torch
+from scipy.sparse.linalg import eigs
+from scipy.sparse.linalg import LinearOperator
+
+def dominant_eigensolver(Hopt, v0, tol):
+    N = v0.shape[0]
+    A = LinearOperator((N, N), matvec=lambda x: Hopt(torch.as_tensor(x)).detach().numpy())
+    vals, vecs = eigs(A, k=1, which='LM', v0=v0.detach().numpy(), tol=tol)
+
+    w = torch.as_tensor(vals[0].real, dtype=v0.dtype, device=v0.device)
+    v = torch.as_tensor(vecs[:, 0].real, dtype=v0.dtype, device=v0.device)
+    return w, v
+
diff --git a/tensornets/fixgaugeqr.py b/tensornets/fixgaugeqr.py
new file mode 100644
index 0000000..66cfd0b
--- /dev/null
+++ b/tensornets/fixgaugeqr.py
@@ -0,0 +1,28 @@
+import torch
+
+
+def fixgauge_qr(A):
+    Q, R = torch.qr(A)
+    Rsgn = torch.diag(R).sign()
+    Q = Q * Rsgn
+    R = Rsgn[:, None] * R
+
+    return Q, R
+
+
+def test_FixQR():
+    M, N = 10, 4
+    # torch.manual_seed(2)
+    A = torch.randn(M, N, dtype=torch.float64)
+
+    Q0, R0 = torch.qr(A)
+    Q1, R1 = fixgauge_qr(A)
+
+    print('R0 = ', R0)
+    print('R1 = ', R1)
+    print('qr0 = ', torch.dist(Q0 @ R0, A))
+    print('qr1 = ', torch.dist(Q1 @ R1, A))
+
+
+if __name__ == "__main__":
+    test_FixQR()
diff --git a/tensornets/orthonormalize.py b/tensornets/orthonormalize.py
new file mode 100644
index 0000000..4ea74f1
--- /dev/null
+++ b/tensornets/orthonormalize.py
@@ -0,0 +1,71 @@
+import torch
+
+from .fixgaugeqr import fixgauge_qr
+from .dominanteigensolver import dominant_eigensolver
+
+
+def left_orthonormalize(A, C, epsilon, maxiter=100):
+    '''
+    C*A = lambda*AL*C
+    '''
+    chi, d = A.shape[0], A.shape[1]
+    A = A.contiguous().view(chi, d * chi)
+
+    def step(C):
+        _, C = fixgauge_qr(C)
+        C = C / C.norm()
+        AC = (C @ A).view(chi * d, chi)
+        AL, Cnew = fixgauge_qr(AC)
+        lamb = Cnew.norm()
+        Cnew = Cnew / lamb
+        diff = torch.dist(Cnew, C)
+        return AL.view(chi, d, chi), Cnew, lamb, diff
+
+    AL, Cnew, lamb, diff = step(C)
+    for it in range(maxiter):
+        def Hopt(x):
+            '''
+            C = A^{T}_L C A
+            '''
+            C = x.view(chi, chi)
+            AC = 
(C @ A).view(chi * d, chi) + y = AL.view(chi * d, chi).t() @ AC + return y.view(-1) + + _, C = dominant_eigensolver(Hopt, Cnew.contiguous().view(-1), diff / 10) + C = C.view(chi, chi) + + AL, Cnew, lamb, diff = step(C) + if (diff < epsilon): + break + else: + C = Cnew + return AL, C, lamb + + +def right_orthonormalize(A, C, epsilon): + ''' + A*C = lambda*C*AR + ''' + AL, C, lamb = left_orthonormalize(A.permute(2, 1, 0), C.permute(1, 0), epsilon) + return AL.permute(2, 1, 0).contiguous(), C.permute(1, 0).contiguous(), lamb + + +if __name__ == '__main__': + torch.manual_seed(42) + chi = 30 + d = 2 + epsilon = 1E-12 + dtype = torch.float64 + A = torch.randn(chi, d, chi, dtype=dtype) + + AL, C, lamb = left_orthonormalize(A, torch.eye(chi, dtype=dtype), epsilon) + print( + torch.dist((C @ A.view(chi, d * chi)).view(chi, d, chi), lamb * (AL.view(chi * d, chi) @ C).view(chi, d, chi))) + AL = AL.view(chi * d, chi) + print(torch.dist(AL.t() @ AL, torch.eye(chi, dtype=dtype))) + + # AR, C, lamb = right_orthonormalize(A, torch.eye(chi, dtype=dtype), epsilon) + # print (torch.dist((A.view(chi*d, chi)@C).view(-1), lamb*(C@AR.view(chi, d*chi)).view(-1))) + # AR = AR.view(chi, d*chi) + # print (AR@AR.t()) diff --git a/tensornets/singlelayer.py b/tensornets/singlelayer.py new file mode 100644 index 0000000..a2eb556 --- /dev/null +++ b/tensornets/singlelayer.py @@ -0,0 +1,48 @@ +import torch + +from .adlib import SVD +svd = SVD.apply +from .ctmrg import CTMRG +from .vumps import vumps_run, vumps_calc + +def projection(T, epsilon=1E-3): + + D = T.shape[0] # double layer bond dimension + + M = T.view(D, -1) + M = M@M.t() + + U, S, _ = svd(M) + + #S = S/S.max() + #chi = (S>epsilon).sum().item() + + #up to truncation error + chi = (torch.cumsum(S, dim=0)/S.sum() <= 1-epsilon).sum().item() + + U = U[:, :chi].view(D, chi) + #print (S/S.max()) + #print (torch.cumsum(S, dim=0)/S.sum() ) + print ('---->truncated from', D, 'to', chi) + + return torch.einsum('abcd,ai,bj,ck,dl->ijkl', (T, U, U, U, U)), U + +def ctmrg(T, chi, maxiter, epsilon): + + with torch.no_grad(): + C, E = CTMRG(T, chi, maxiter, epsilon) + + Z1 = torch.einsum('ab,bcd,fd,gha,chij,fjk,lg,mil,mk', (C,E,C,E,T,E,C,E,C)) + Z3 = torch.einsum('ab,bc,cd,da', (C,C,C,C)) + Z2 = torch.einsum('ab,bcd,de,fa,gcf,ge',(C,E,C,C,E,C)) + lnZ = torch.log(Z1.abs()) + torch.log(Z3.abs()) - 2.*torch.log(Z2.abs()) + + return lnZ + +def vumps(T, chi, maxiter, epsilon): + with torch.no_grad(): + _, AC, F, C, _ = vumps_run(T, chi, epsilon, maxiter) + + lnZ = torch.log(vumps_calc(T, AC, F, C)) + return lnZ + diff --git a/tensornets/tests/.test_eigh.py b/tensornets/tests/.test_eigh.py new file mode 100644 index 0000000..c35d7c2 --- /dev/null +++ b/tensornets/tests/.test_eigh.py @@ -0,0 +1,14 @@ +import numpy as np +import torch +import sys, os +testdir = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(testdir+"/..") + +from adlib.eigh import EigenSolver + +def test_eigs(): + M = 2 + torch.manual_seed(42) + A = torch.rand(M, M, dtype=torch.float64) + A = torch.nn.Parameter(A+A.t()) + assert(torch.autograd.gradcheck(EigenSolver.apply, A, eps=1e-6, atol=1e-4)) diff --git a/tensornets/tests/test_power.py b/tensornets/tests/test_power.py new file mode 100644 index 0000000..38aa231 --- /dev/null +++ b/tensornets/tests/test_power.py @@ -0,0 +1,38 @@ +import numpy as np +import torch +import sys, os +testdir = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(testdir+"/..") + +from adlib.power import FixedPoint + +def test_backward(): + N = 4 + 
torch.manual_seed(2) + A = torch.rand(N, N, dtype=torch.float64, requires_grad=True) + x0 = torch.rand(N, dtype=torch.float64) + x0 = x0/x0.norm() + tol = 1E-10 + + input = A, x0, tol + assert(torch.autograd.gradcheck(FixedPoint.apply, input, eps=1E-6, atol=tol)) + +def test_forward(): + torch.manual_seed(42) + N = 100 + tol = 1E-8 + dtype = torch.float64 + A = torch.randn(N, N, dtype=dtype) + A = A+A.t() + + w, v = torch.symeig(A, eigenvectors=True) + idx = torch.argmax(w.abs()) + + v_exact = v[:, idx] + v_exact = v_exact[0].sign() * v_exact + + x0 = torch.rand(N, dtype=dtype) + x0 = x0/x0.norm() + x = FixedPoint.apply(A, x0, tol) + + assert(torch.allclose(v_exact, x, rtol=tol, atol=tol)) diff --git a/tensornets/tests/test_qr.py b/tensornets/tests/test_qr.py new file mode 100644 index 0000000..2dd4287 --- /dev/null +++ b/tensornets/tests/test_qr.py @@ -0,0 +1,14 @@ +import numpy as np +import torch +import sys, os +testdir = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(testdir+"/..") + +from adlib.qr import QR + +def test_qr(): + M, N = 4, 6 + torch.manual_seed(2) + A = torch.randn(M, N, dtype=torch.float64) + A.requires_grad=True + assert(torch.autograd.gradcheck(QR.apply, A, eps=1e-4, atol=1e-2)) diff --git a/tensornets/tests/test_svd.py b/tensornets/tests/test_svd.py new file mode 100644 index 0000000..40fb84a --- /dev/null +++ b/tensornets/tests/test_svd.py @@ -0,0 +1,13 @@ +import numpy as np +import torch +import sys, os +testdir = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(testdir+"/..") + +from adlib.svd import SVD + +def test_svd(): + M, N = 20, 16 + torch.manual_seed(2) + input = torch.rand(M, N, dtype=torch.float64, requires_grad=True) + assert(torch.autograd.gradcheck(SVD.apply, input, eps=1e-6, atol=1e-4)) diff --git a/tensornets/torchncon.py b/tensornets/torchncon.py new file mode 100644 index 0000000..34825d2 --- /dev/null +++ b/tensornets/torchncon.py @@ -0,0 +1,330 @@ +import torch +import collections + +""" A module for the function ncon, which does contractions of several tensors. +""" + + +def ncon(AA, v, order=None, forder=None, check_indices=True): + """ AA = [A1, A2, ..., Ap] list of tensors. + + v = (v1, v2, ..., vp) tuple of lists of indices e.g. v1 = [3 4 -1] labels + the three indices of tensor A1, with -1 indicating an uncontracted index + (open leg) and 3 and 4 being the contracted indices. + + order, if present, contains a list of all positive indices - if not + [1 2 3 4 ...] by default. This is the order in which they are contracted. + + forder, if present, contains the final ordering of the uncontracted indices + - if not, [-1 -2 ..] by default. + + There is some leeway in the way the inputs are given. For example, + instead of giving a list of tensors as the first argument one can + give some different iterable of tensors, such as a tuple, or a + single tensor by itself (anything that has the attribute "shape" + will be considered a tensor). + """ + + # We want to handle the tensors as a list, regardless of what kind + # of iterable we are given. In addition, if only a single element is + # given, we make list out of it. Inputs are assumed to be non-empty. + if hasattr(AA, "shape"): + AA = [AA] + else: + AA = list(AA) + v = list(v) + if not isinstance(v[0], collections.Iterable): + # v is not a list of lists, so make it such. 
+ v = [v] + else: + v = list(map(list, v)) + + if order == None: + order = create_order(v) + if forder == None: + forder = create_forder(v) + + if check_indices: + # Raise a RuntimeError if the indices are wrong. + do_check_indices(AA, v, order, forder) + + # If the graph is dinconnected, connect it with trivial indices that + # will be contracted at the very end. + connect_graph(AA, v, order) + + while len(order) > 0: + tcon = get_tcon(v, order[0]) # tcon = tensors to be contracted + # Find the indices icon that are to be contracted. + if len(tcon) == 1: + tracing = True + icon = [order[0]] + else: + tracing = False + icon = get_icon(v, tcon) + # Position in tcon[0] and tcon[1] of indices to be contracted. + # In the case of trace, pos2 = [] + pos1, pos2 = get_pos(v, tcon, icon) + if tracing: + # Trace on a tensor + new_A = trace(AA[tcon[0]], axis1=pos1[0], axis2=pos1[1]) + else: + # Contraction of 2 tensors + new_A = con(AA[tcon[0]], AA[tcon[1]], (pos1, pos2)) + AA.append(new_A) + v.append(find_newv(v, tcon, icon)) # Add the v for the new tensor + for i in sorted(tcon, reverse=True): + # Delete the contracted tensors and indices from the lists. + # tcon is reverse sorted so that tensors are removed starting from + # the end of AA, otherwise the order would get messed. + del AA[i] + del v[i] + order = renew_order(order, icon) # Update order + + vlast = v[0] + A = AA[0] + A = permute_final(A, vlast, forder) + return A + + +# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # + +def create_order(v): + """ Identify all unique, positive indices and return them sorted. """ + flat_v = sum(v, []) + x = [i for i in flat_v if i > 0] + # Converting to a set and back removes duplicates + x = list(set(x)) + return sorted(x) + + +def create_forder(v): + """ Identify all unique, negative indices and return them reverse sorted + (-1 first). + """ + flat_v = sum(v, []) + x = [i for i in flat_v if i < 0] + # Converting to a set and back removes duplicates + x = list(set(x)) + return sorted(x, reverse=True) + + +def connect_graph(AA, v, order): + """ Connect the graph of tensors to be contracted by trivial + indices, if necessary. Add these trivial indices to the end of the + contraction order. + + AA, v and order are modified in place. + """ + # Build ccomponents, a list of the connected components of the graph, + # where each component is represented by a a set of indices. + unvisited = set(range(len(AA))) + visited = set() + ccomponents = [] + while unvisited: + component = set() + next_visit = unvisited.pop() + to_visit = {next_visit} + while to_visit: + i = to_visit.pop() + unvisited.discard(i) + component.add(i) + visited.add(i) + # Get the indices of tensors neighbouring AA[i]. + i_inds = set(v[i]) + neighs = (j for j, j_inds in enumerate(v) if i_inds.intersection(j_inds)) + for neigh in neighs: + if neigh not in visited: + to_visit.add(neigh) + ccomponents.append(component) + # If there is more than one connected component, take one of them, a + # take an arbitrary tensor (called c) out of it, and connect that + # tensor with an arbitrary tensor (called d) from all the other + # components using a trivial index. 
+ c = ccomponents.pop().pop() + while ccomponents: + d = ccomponents.pop().pop() + A_c = AA[c] + A_d = AA[d] + c_axis = len(v[c]) + d_axis = len(v[d]) + try: + AA[c] = A_c.expand_dims(c_axis, direction=1) + except AttributeError: + AA[c] = np.expand_dims(A_c, c_axis) + try: + AA[d] = A_d.expand_dims(d_axis, direction=-1) + except AttributeError: + AA[d] = np.expand_dims(A_d, d_axis) + try: + dim_num = max(order) + 1 + except ValueError: + dim_num = 1 + v[c].append(dim_num) + v[d].append(dim_num) + order.append(dim_num) + return None + + +def get_tcon(v, index): + """ Gets the list indices in AA of the tensors that have index as their + leg. + """ + tcon = [] + for i, inds in enumerate(v): + if index in inds: + tcon.append(i) + l = len(tcon) + # If check_indices is called and it does its work properly then these + # checks should in fact be unnecessary. + if l > 2: + raise ValueError('In ncon.get_tcon, more than two tensors share a ' + 'contraction index.') + elif l < 1: + raise ValueError('In ncon.get_tcon, less than one tensor share a ' + 'contraction index.') + elif l == 1: + # The contraction is a trace. + how_many = v[tcon[0]].count(index) + if how_many != 2: + # Only one tensor has this index but it is not a trace because it + # does not occur twice for that tensor. + raise ValueError('In ncon.get_tcon, a trace index is listed ' + '!= 2 times for the same tensor.') + return tcon + + +def get_icon(v, tcon): + """ Returns a list of indices that are to be contracted when contractions + between the two tensors numbered in tcon are contracted. """ + inds1 = v[tcon[0]] + inds2 = v[tcon[1]] + icon = set(inds1).intersection(inds2) + icon = list(icon) + return icon + + +def get_pos(v, tcon, icon): + """ Get the positions of the indices icon in the list of legs the tensors + tcon to be contracted. + """ + pos1 = [[i for i, x in enumerate(v[tcon[0]]) if x == e] for e in icon] + pos1 = sum(pos1, []) + if len(tcon) < 2: + pos2 = [] + else: + pos2 = [[i for i, x in enumerate(v[tcon[1]]) if x == e] for e in icon] + pos2 = sum(pos2, []) + return pos1, pos2 + + +def find_newv(v, tcon, icon): + """ Find the list of indices for the new tensor after contraction of + indices icon of the tensors tcon. + """ + if len(tcon) == 2: + newv = v[tcon[0]] + v[tcon[1]] + else: + newv = v[tcon[0]] + newv = [i for i in newv if i not in icon] + return newv + + +def renew_order(order, icon): + """ Returns the new order with the contracted indices removed from it. """ + return [i for i in order if i not in icon] + + +def permute_final(A, v, forder): + """ Returns the final tensor A with its legs permuted to the order given + in forder. + """ + perm = [v.index(i) for i in forder] + return A.permute(tuple(perm)) + + +def do_check_indices(AA, v, order, forder): + """ Check that + 1) the number of tensors in AA matches the number of index lists in v. + 2) every tensor is given the right number of indices. + 3) every contracted index is featured exactly twice and every free index + exactly once. + 4) the dimensions of the two ends of each contracted index match. + """ + + # 1) + if len(AA) != len(v): + raise ValueError(('In ncon.do_check_indices, the number of tensors %i' + ' does not match the number of index lists %i') + % (len(AA), len(v))) + + # 2) + # Create a list of lists with the shapes of each A in AA. 
+ shapes = list(map(lambda A: list(A.shape), AA)) + for i, inds in enumerate(v): + if len(inds) != len(shapes[i]): + raise ValueError(('In ncon.do_check_indices, len(v[%i])=%i ' + 'does not match the numbers of indices of ' + 'AA[%i] = %i') % (i, len(inds), i, + len(shapes[i]))) + + # 3) and 4) + # v_pairs = [[(0,0), (0,1), (0,2), ...], [(1,0), (1,1), (1,2), ...], ...] + v_pairs = [[(i, j) for j in range(len(s))] for i, s in enumerate(v)] + v_pairs = sum(v_pairs, []) + v_sum = sum(v, []) + # For t, o in zip(v_pairs, v_sum) t is the tuple of the number of + # the tensor and the index and o is the contraction order of that + # index. We group these tuples by the contraction order. + order_groups = [[t for t, o in zip(v_pairs, v_sum) if o == e] + for e in order] + forder_groups = [[1 for fo in v_sum if fo == e] for e in forder] + for i, o in enumerate(order_groups): + if len(o) != 2: + raise ValueError(('In ncon.do_check_indices, the contracted index ' + '%i is not featured exactly twice in v.') % order[i]) + else: + A0, ind0 = o[0] + A1, ind1 = o[1] + try: + compatible = AA[A0].compatible_indices(AA[A1], ind0, ind1) + except AttributeError: + compatible = AA[A0].shape[ind0] == AA[A1].shape[ind1] + if not compatible: + raise ValueError('In ncon.do_check_indices, for the ' + 'contraction index %i, the leg %i of tensor ' + 'number %i and the leg %i of tensor number ' + '%i are not compatible.' + % (order[i], ind0, A0, ind1, A1)) + for i, fo in enumerate(forder_groups): + if len(fo) != 1: + raise ValueError(('In ncon.do_check_indices, the free index ' + '%i is not featured exactly once in v.') % forder[i]) + + # All is well if we made it here. + return True + + +#################################################################### +# The following are simple wrappers around numpy/Tensor functions, # +# but may be replaced with fancier stuff later. 
# +#################################################################### + + +def con(A, B, inds): + if torch.is_tensor(A) and torch.is_tensor(B): + return torch.tensordot(A, B, inds) + else: + return A.dot(B, inds) + + +def trace(A, axis1=0, axis2=1): + return A.trace(axis1=axis1, axis2=axis2) + + +if __name__ == '__main__': + A = torch.randn(2, 3, 4, 5, dtype=torch.float64) + B = torch.randn(4, 5, 2, dtype=torch.float64) + C_ncon = ncon([A, B], ([-1, -2, 1, 2], [1, 2, -3]), [1, 2], [-3, -1, -2]) + C_einsum = torch.einsum('abcd,cde->eab', (A, B)) + + print((C_ncon - C_einsum).abs().sum()) diff --git a/tensornets/trg.py b/tensornets/trg.py new file mode 100644 index 0000000..a51140a --- /dev/null +++ b/tensornets/trg.py @@ -0,0 +1,41 @@ +import torch + +from .adlib import SVD +svd = SVD.apply + +def renormalize(*args): + T, chi, epsilon = args + + D = T.shape[0] + Ma = T.view(D**2, D**2) + Mb = T.permute(1, 2, 0, 3).contiguous().view(D**2, D**2) + + Ua, Sa, Va = svd(Ma) + Ub, Sb, Vb = svd(Mb) + + D_new = min(min(D**2, chi), min((Sa>epsilon).sum().item(), (Sb>epsilon).sum().item())) + + S1 = (Ua[:, :D_new]* torch.sqrt(Sa[:D_new])).view(D, D, D_new) + S3 = (Va[:, :D_new]* torch.sqrt(Sa[:D_new])).view(D, D, D_new) + S2 = (Ub[:, :D_new]* torch.sqrt(Sb[:D_new])).view(D, D, D_new) + S4 = (Vb[:, :D_new]* torch.sqrt(Sb[:D_new])).view(D, D, D_new) + + return torch.einsum('xwu,yxl,yzd,wzr->uldr', (S2, S3, S4, S1)) + +def TRG(T, chi, no_iter, epsilon=1E-15): + + lnZ = 0.0 + for n in range(no_iter): + maxval = T.abs().max() + T = T/maxval + lnZ += 2**(no_iter-n)*torch.log(maxval) + T = renormalize(T, chi, epsilon) + + trace = 0.0 + for x in range(T.shape[0]): + for y in range(T.shape[1]): + trace += T[x, y, x, y] + lnZ += torch.log(trace) + + return lnZ/2**no_iter + diff --git a/tensornets/utils.py b/tensornets/utils.py new file mode 100644 index 0000000..b914e28 --- /dev/null +++ b/tensornets/utils.py @@ -0,0 +1,63 @@ +import re +import torch + +def kronecker_product(t1, t2): + """ + Computes the Kronecker product between two tensors. + See https://en.wikipedia.org/wiki/Kronecker_product + """ + t1_height, t1_width = t1.size() + t2_height, t2_width = t2.size() + out_height = t1_height * t2_height + out_width = t1_width * t2_width + + tiled_t2 = t2.repeat(t1_height, t1_width) + expanded_t1 = ( + t1.unsqueeze(2) + .unsqueeze(3) + .repeat(1, t2_height, t2_width, 1) + .view(out_height, out_width) + ) + + return expanded_t1 * tiled_t2 + +def symmetrize(A): + ''' + A(phy, up, left, down, right) + left-right, up-down, diagonal symmetrize + ''' + Asymm = (A + A.permute(0, 1, 4, 3, 2))/2. # left-right symmetry + Asymm = (Asymm + Asymm.permute(0, 3, 2, 1, 4))/2. # up-down symmetry + Asymm = (Asymm + Asymm.permute(0, 4, 3, 2, 1))/2. # skew-diagonal symmetry + Asymm = (Asymm + Asymm.permute(0, 2, 1, 4, 3))/2. 
# diagonal symmetry + + return Asymm/Asymm.norm() + +def save_checkpoint(checkpoint_path, model, optimizer): + state = {'state_dict': model.state_dict(), + 'optimizer' : optimizer.state_dict()} + #print(model.state_dict().keys()) + torch.save(state, checkpoint_path) + #print('model saved to %s' % checkpoint_path) + +def load_checkpoint(checkpoint_path, args, model): + print( 'load old model from %s ' % checkpoint_path ) + print( 'Dold = ', re.search('_D([0-9]*)_', checkpoint_path).group(1) ) + Dold = int(re.search('_D([0-9]*)_', checkpoint_path).group(1)) + + d, D = args.d, args.D + dtype, device = model.A.dtype, model.A.device + + if (Dold != D): + B = torch.rand( d, Dold, Dold, Dold, Dold, dtype=dtype, device=device) + model.A = torch.nn.Parameter(B) + + state = torch.load(checkpoint_path, map_location=lambda storage, loc: storage) + model.load_state_dict(state['state_dict']) + + if (Dold != D): + Aold = model.A.data + B = 1E-2*torch.rand( d, D, D, D, D, dtype=dtype, device=device) + B[:, :Dold, :Dold, :Dold, :Dold] = Aold.reshape(d, Dold, Dold, Dold, Dold) + model.A = torch.nn.Parameter(B) + diff --git a/tensornets/vumps.py b/tensornets/vumps.py new file mode 100644 index 0000000..033689a --- /dev/null +++ b/tensornets/vumps.py @@ -0,0 +1,104 @@ +import torch +import time + +import numpy as np +from .torchncon import ncon +from .orthonormalize import left_orthonormalize +from .fixgaugeqr import fixgauge_qr +from .dominanteigensolver import dominant_eigensolver + +def symmetrize(A): + Asymm = A + A.permute(2, 1, 0) + return Asymm / Asymm.norm() + +def H_F(AL, T, x): + chi, d = AL.shape[0], AL.shape[1] + x = x.view(chi, d, chi).to(T.device) + y = ncon([x, T, AL, AL], ([2, 1, 3], [4, 1, 5, -2], [2, 4, -1], [3, 5, -3])) + return y.contiguous().view(-1).to('cpu') + + +def H_AC(F, T, x): + chi, d = F.shape[0], F.shape[1] + x = x.view(chi, d, chi).to(T.device) + y = ncon([x, T, F, F], ([2, 1, 3], [1, 4, -2, 5], [2, 4, -1], [3, 5, -3])) + return y.contiguous().view(-1).to('cpu') + + +def H_C(F, x): + chi, d = F.shape[0], F.shape[1] + x = x.view(chi, chi).to(F.device) + y = ncon([x, F, F], ([1, 2], [1, 3, -1], [2, 3, -2])) + return y.contiguous().view(-1).to('cpu') + + +def vumps_run(T, chi, epsilon, maxiter=100, Cinit=None, ALinit=None, ACinit=None, Finit=None): + d = T.shape[0] + dtype, device = T.dtype, T.device + + if Cinit is None: + F = torch.rand(chi * d * chi, dtype=dtype, device=device) + F = symmetrize(F.view(chi, d, chi)) + + A = torch.rand(chi, d, chi, dtype=dtype, device=device) + A = symmetrize(A) + + AL, C, _ = left_orthonormalize(A, torch.eye(chi, chi, dtype=dtype, device=device), epsilon / 10) + + AC = C @ A.view(chi, d * chi) + AC = AC.view(chi * d, chi) @ C.t() + AC = symmetrize(AC.view(chi, d, chi)) + else: + F = Finit.detach() + AL = ALinit.detach() + AC = ACinit.detach() + C = Cinit.detach() + # start = time.time() + # rtn_history = [] + + for it in range(maxiter): + print ('it', it, maxiter, epsilon) + lambF, F = dominant_eigensolver(lambda x: H_F(AL, T, x), F.view(-1).to('cpu'), tol=epsilon / 10) + print ('got F') + F = symmetrize(F.view(chi, d, chi)).to(device) + + # compute error + MAC = H_AC(F, T, (AL.view(chi * d, chi) @ C).view(-1)).view(chi, d, chi).to(device) + MAC = MAC - ncon([MAC, AL, AL], ([1, 2, -3], [1, 2, 3], [-1, -2, 3])) + diff = MAC.norm() + + lambAC, AC = dominant_eigensolver(lambda x: H_AC(F, T, x), AC.view(-1).to('cpu'), tol=epsilon / 10) + + print ('got AC') + AC = symmetrize(AC.view(chi, d, chi).to(device)) + + lambC, C = dominant_eigensolver(lambda 
x: H_C(F, x), C.view(-1).to('cpu'), tol=epsilon / 10) + + print ('got C') + C = C.view(chi, chi).to(device) + C = (C + C.t()) / 2 + C = C / C.norm() + + QAC, RAC = fixgauge_qr(AC.view(chi * d, chi)) + QC, RC = fixgauge_qr(C) + print ('qr done') + AL = (QAC @ QC.t()).view(chi, d, chi) + # diff = torch.dist(RAC, RC) + + # print(it, diff.item(), torch.log(lambF).item(), torch.log(lambAC / lambC).item()) + # if it % 5 == 0: + # rtn_history.append([time.time() - start, torch.log(lambAC / lambC).item()]) + if (diff < epsilon): + break + + if diff >= epsilon: + print('WARNING: vumps reaching maxiter, diff=', diff.item()) + + return torch.log(lambAC / lambC), AC, F, C, AL + + +def vumps_calc(T, AC, F, C): + return torch.dot(H_F(AC, T, F).to(F.device), F.view(-1)) \ + / torch.dot(AC.view(-1), AC.view(-1)) \ + / torch.dot(H_C(F, C).to(F.device), C.view(-1)) + -- GitLab
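As a quick orientation for reviewers, here is a minimal usage sketch (not part of the patch) showing how the pieces added above fit together: it builds the rank-4 transfer tensor of the 2D classical Ising model and contracts it with the ctmrg wrapper from contraction.py. The inverse temperature, chi, maxiter, and epsilon values are arbitrary illustrative choices, and the script assumes the patch has been applied and is run from the repository root so that contraction and tensornets are importable.

import math
import torch
from contraction import ctmrg

dtype = torch.float64
beta = 0.5  # inverse temperature, illustrative value only

# Boltzmann matrix of one Ising bond and its symmetric square root
Q = torch.tensor([[math.exp( beta), math.exp(-beta)],
                  [math.exp(-beta), math.exp( beta)]], dtype=dtype)
w, v = torch.symeig(Q, eigenvectors=True)
Qsqrt = v @ torch.diag(w.sqrt()) @ v.t()

# site tensor T(up, left, down, right) = sum_s Qsqrt[s,u] Qsqrt[s,l] Qsqrt[s,d] Qsqrt[s,r]
T = torch.einsum('su,sl,sd,sr->uldr', (Qsqrt, Qsqrt, Qsqrt, Qsqrt))

# chi/maxiter/epsilon are illustrative; ctmrg returns the per-site log partition function
lnZ = ctmrg(T, chi=30, maxiter=50, epsilon=1E-10, use_checkpoint=False)
print('lnZ per site:', lnZ.item())

Because the CTMRG contraction is an ordinary differentiable PyTorch computation (that is what the adlib backward functions are for), T could equally well be built from a tensor with requires_grad, as dimer_covering.py does with the symmetrized boundary PEPS tensor.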