From c21098b1b8033fbfc17313dbc2d931ec91cf6461 Mon Sep 17 00:00:00 2001 From: Lei Wang Date: Thu, 13 Dec 2018 22:18:51 +0800 Subject: [PATCH] lanczos can have matsize=1 --- lanczos.py | 14 ++++++++------ tests/test3.py | 18 ++++++++++++++++++ tests/test4.py | 19 +++++++++++++++++++ vmps.py | 15 +++++++++++---- 4 files changed, 56 insertions(+), 10 deletions(-) create mode 100644 tests/test3.py create mode 100644 tests/test4.py diff --git a/lanczos.py b/lanczos.py index f637a92..05c8637 100755 --- a/lanczos.py +++ b/lanczos.py @@ -44,14 +44,16 @@ def lanczos(Hopt, phi0, matsize): Bb[i+1] = torch.sqrt(res) Phi[i+1] = Phi[i+1]/Bb[i+1] #normalize |Phi_n+1> - - Hlanc = torch.diag(torch.stack(Have[:matsize]), 0) \ - + torch.diag(torch.stack(Bb[1:matsize]), -1) \ - + torch.diag(torch.stack(Bb[1:matsize]), 1) - w, _ = torch.symeig(Hlanc, eigenvectors=True) + + if matsize>1: + Hlanc = torch.diag(torch.stack(Have[:matsize]), 0) \ + + torch.diag(torch.stack(Bb[1:matsize]), -1) \ + + torch.diag(torch.stack(Bb[1:matsize]), 1) + w, _ = torch.symeig(Hlanc, eigenvectors=True) + else: + w = Have[0] return w - if __name__=='__main__': import time diff --git a/tests/test3.py b/tests/test3.py new file mode 100644 index 0000000..172f897 --- /dev/null +++ b/tests/test3.py @@ -0,0 +1,18 @@ +import torch + +D = 10 +d = 4 +A = torch.randn(D, d, D) + +def Hopt(x): + x = x.view(D, D) + return ((A.view(D*d, D) @ x).view(D, d*D) @ A.permute(1,2,0).contiguous().view(d*D, D)).view(D**2) + +C = torch.einsum('ldr,idj->lirj', (A, A)).contiguous().view(D**2, D**2) + +x = torch.randn(D**2) + +y = torch.mv(C, x) +z = Hopt(x) + +print ( (y-z).abs().sum() ) diff --git a/tests/test4.py b/tests/test4.py new file mode 100644 index 0000000..a3b40c6 --- /dev/null +++ b/tests/test4.py @@ -0,0 +1,19 @@ +import torch + +D = 10 +d = 2 +A = torch.randn(D, d, D, dtype=torch.float64) +T = torch.randn(d, d, d, d, dtype=torch.float64) + +def Hopt(x): + Tx = (T.view(-1, d) @ x.view(D, d, D).permute(1, 0, 2).contiguous().view(d,-1)).view(d,d,d,D,D).permute(1,3,0,2,4).contiguous() + return ((A.view(D, d*D)@Tx.view(d*D, d*d*D)).view(D*d, d*D) @ A.permute(1,2,0).contiguous().view(d*D, D)).view(D**2*d) + +B = torch.einsum('ldr,adcb,icj->lairbj', (A, T, A)).contiguous().view(D**2*d, D**2*d) + +x = torch.randn(D**2*d, dtype=torch.float64) + +y = torch.mv(B, x) +z = Hopt(x) + +print ( (y-z).abs().sum() ) diff --git a/vmps.py b/vmps.py index 428e8ea..1795b59 100644 --- a/vmps.py +++ b/vmps.py @@ -9,22 +9,29 @@ def mpsrg(A, T, use_lanczos=False): Asymm = (A + A.permute(2, 1, 0))*0.5 D, d = Asymm.shape[0], Asymm.shape[1] #t0 = time.time() - B = torch.einsum('ldr,adcb,icj->lairbj', (Asymm, T, Asymm)).contiguous().view(D**2*d, D**2*d) - C = torch.einsum('ldr,idj->lirj', (Asymm, Asymm)).contiguous().view(D**2, D**2) if use_lanczos: phi0 = Asymm.view(D**2*d) phi0 = phi0/phi0.norm() - w = lanczos(lambda x: torch.mv(B,x), phi0, 100) + def Hopt(x): + Tx = (T.view(-1, d) @ x.view(D, d, D).permute(1, 0, 2).contiguous().view(d,-1)).view(d,d,d,D,D).permute(1,3,0,2,4).contiguous() + return ((Asymm.view(D, d*D)@Tx.view(d*D, d*d*D)).view(D*d, d*D)@Asymm.permute(1,2,0).contiguous().view(d*D, D)).view(D**2*d) + + w = lanczos(Hopt, phi0, 100) else: + B = torch.einsum('ldr,adcb,icj->lairbj', (Asymm, T, Asymm)).contiguous().view(D**2*d, D**2*d) w, _ = torch.symeig(B, eigenvectors=True) lnZ1 = torch.log(w.abs().max()) if use_lanczos: phi0 = Asymm.sum(1).view(D**2) phi0 = phi0/phi0.norm() - w = lanczos(lambda x: torch.mv(C,x), phi0, 100) + def Hopt(x): + x = x.view(D, D) + return ((Asymm.view(D*d, D) @ x).view(D, d*D) @ Asymm.permute(1,2,0).contiguous().view(d*D, D)).view(D**2) + w = lanczos(Hopt, phi0, 100) else: + C = torch.einsum('ldr,idj->lirj', (Asymm, Asymm)).contiguous().view(D**2, D**2) w, _ = torch.symeig(C, eigenvectors=True) lnZ2 = torch.log(w.abs().max()) -- GitLab