Advanced Computing Platform for Theoretical Physics

Committing large files makes the server unstable — please commit only code, not large files.

Commit e0459b6a by Lei Wang

### Backpropagation through Lanczos is slower than the forward pass; Lanczos is faster than symeig when Dcut > 50

parent f1e0e64e
lanczos.py 0 → 100755
 import torch def lanczos(Hopt, phi0, matsize): ''' give function Hopt initial vector (must be normalized) return Hlanc and laczos vectors ''' dim = phi0.shape[0] #Have = torch.zeros(matsize, dtype=phi0.dtype, device=phi0.device) #Bb = torch.zeros(matsize, dtype=phi0.dtype, device=phi0.device) #Phi = torch.zeros((matsize,dim), dtype=phi0.dtype, device=phi0.device) Have = [torch.tensor([0.0], dtype=phi0.dtype, device=phi0.device) for _ in range(matsize)] Bb = [torch.tensor([0.0], dtype=phi0.dtype, device=phi0.device) for _ in range(matsize)] Phi = [torch.zeros(dim, dtype=phi0.dtype, device=phi0.device) for _ in range(matsize)] Phi[0]= phi0 for i in range(matsize): phitemp = Hopt(Phi[i]) Have[i] = (Phi[i]*phitemp).sum() if i=H|Phi_0>-a_0|Phi_0> else: Phi[i+1]=phitemp-Have[i]*Phi[i]-Bb[i]*Phi[i-1] #|Phi_n+1>=H|Phi_n>-a_n|Phi_n>-b_n|Phi_n-1> #--- Schmit orthogonalization--------------------------------- orth = [(Phi[m]*Phi[i+1]).sum() for m in range(i+1)] for m in range(i+1): Phi[i+1] = Phi[i+1] - Phi[m] * orth[m] #--- Schmit orthogonalization--------------------------------- res = (Phi[i+1]*Phi[i+1]).sum() if(res<1E-12): matsize=i+1 break Bb[i+1] = torch.sqrt(res) Phi[i+1] = Phi[i+1]/Bb[i+1] #normalize |Phi_n+1> Hlanc = torch.diag(torch.stack(Have[:matsize]), 0) \ + torch.diag(torch.stack(Bb[1:matsize]), -1) \ + torch.diag(torch.stack(Bb[1:matsize]), 1) w, _ = torch.symeig(Hlanc, eigenvectors=True) return w if __name__=='__main__': import time import argparse parser = argparse.ArgumentParser(description='') parser.add_argument("-float32", action='store_true', help="use float32") parser.add_argument("-cuda", type=int, default=-1, help="use GPU") args = parser.parse_args() device = torch.device("cpu" if args.cuda<0 else "cuda:"+str(args.cuda)) dtype = torch.float32 if args.float32 else torch.float64 N = 1000 M = 100 Ham = torch.randn(N, N, dtype=dtype, device=device) Ham = (Ham + Ham.t())/2. 
print (Ham.shape) t0 = time.time() w, _ = torch.symeig(Ham, eigenvectors=True) print (w[:10]) print (w[-10:]) print ('full diag', time.time()- t0) #lanczos Hopt = lambda x: torch.mv(Ham,x) t0 = time.time() phi0 = torch.randn(N, dtype=dtype, device=device) phi0 = phi0/phi0.norm() w = lanczos(Hopt, phi0, M) print (w[:10]) print (w[-10:]) print ('lanczos', time.time()- t0)
 import torch torch.set_num_threads(4) from lanczos import lanczos def mpsrg(A, T): ... ... @@ -10,8 +11,15 @@ def mpsrg(A, T): C = torch.einsum('ldr,idj->lirj', (Asymm, Asymm)).contiguous().view(D**2, D**2) w, _ = torch.symeig(B, eigenvectors=True) #phi0 = Asymm.view(D**2*d) #phi0 = phi0/phi0.norm() #w = lanczos(lambda x: torch.mv(B,x), phi0, 100) lnZ1 = torch.log(w.abs().max()) w, _ = torch.symeig(C, eigenvectors=True) #phi0 = Asymm.sum(1).view(D**2) #phi0 = phi0/phi0.norm() #w = lanczos(lambda x: torch.mv(C,x), phi0, 100) lnZ2 = torch.log(w.abs().max()) #lnZ1 = 0.0 ... ... @@ -44,13 +52,13 @@ def vmps(T, d, D, Nepochs=50, Ainit=None): optimizer.zero_grad() #print ('einsum', time.time()- t0) #print ((B-B.t()).abs().sum(), (C-C.t()).abs().sum()) #t0 = time.time() t0 = time.time() loss = mpsrg(A, T.detach()) # loss = -lnZ , here we optimize over A #print ('mpsrg', time.time()- t0) #print (' loss', loss.item()) #t0 = time.time() print ('mpsrg', time.time()- t0) print (' loss', loss.item()) t0 = time.time() loss.backward(retain_graph=False) #print ('backward', time.time()- t0) print ('backward', time.time()- t0) return loss for epoch in range(Nepochs): ... ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!