
Commit e0459b6a authored by Lei Wang

backprop through lanczos is slower than forward; lanczos is faster than symeig when Dcut>50

parent f1e0e64e
lanczos.py

import torch

def lanczos(Hopt, phi0, matsize):
    '''
    Given a matrix-vector function Hopt and a normalized initial vector phi0,
    build a Krylov space of dimension at most matsize and return the
    eigenvalues of the tridiagonal Lanczos matrix.
    '''
    dim = phi0.shape[0]
    # lists of scalars/vectors (instead of the preallocated tensors below)
    # so that autograd can backprop through each Lanczos step
    #Have = torch.zeros(matsize, dtype=phi0.dtype, device=phi0.device)
    #Bb = torch.zeros(matsize, dtype=phi0.dtype, device=phi0.device)
    #Phi = torch.zeros((matsize,dim), dtype=phi0.dtype, device=phi0.device)
    Have = [torch.tensor([0.0], dtype=phi0.dtype, device=phi0.device) for _ in range(matsize)]
    Bb = [torch.tensor([0.0], dtype=phi0.dtype, device=phi0.device) for _ in range(matsize)]
    Phi = [torch.zeros(dim, dtype=phi0.dtype, device=phi0.device) for _ in range(matsize)]
    Phi[0] = phi0
    for i in range(matsize):
        phitemp = Hopt(Phi[i])
        Have[i] = (Phi[i]*phitemp).sum()    # a_i = <Phi_i|H|Phi_i>
        if i < matsize-1:
            if i == 0:
                # |Phi_1> = H|Phi_0> - a_0|Phi_0>
                Phi[i+1] = phitemp - Have[i]*Phi[i]
            else:
                # |Phi_{i+1}> = H|Phi_i> - a_i|Phi_i> - b_i|Phi_{i-1}>
                Phi[i+1] = phitemp - Have[i]*Phi[i] - Bb[i]*Phi[i-1]
            #--- full Gram-Schmidt reorthogonalization -------------------
            orth = [(Phi[m]*Phi[i+1]).sum() for m in range(i+1)]
            for m in range(i+1):
                Phi[i+1] = Phi[i+1] - Phi[m]*orth[m]
            #-------------------------------------------------------------
            res = (Phi[i+1]*Phi[i+1]).sum()
            if res < 1E-12:
                matsize = i+1   # Krylov space exhausted: truncate and stop
                break
            Bb[i+1] = torch.sqrt(res)
            Phi[i+1] = Phi[i+1]/Bb[i+1]     # normalize |Phi_{i+1}>
    # assemble and diagonalize the tridiagonal Lanczos matrix
    Hlanc = torch.diag(torch.stack(Have[:matsize]), 0) \
          + torch.diag(torch.stack(Bb[1:matsize]), -1) \
          + torch.diag(torch.stack(Bb[1:matsize]), 1)
    w, _ = torch.symeig(Hlanc, eigenvectors=True)
    return w
if __name__ == '__main__':
    import time
    import argparse
    parser = argparse.ArgumentParser(description='')
    parser.add_argument("-float32", action='store_true', help="use float32")
    parser.add_argument("-cuda", type=int, default=-1, help="use GPU")
    args = parser.parse_args()
    device = torch.device("cpu" if args.cuda < 0 else "cuda:"+str(args.cuda))
    dtype = torch.float32 if args.float32 else torch.float64

    # benchmark full diagonalization against Lanczos on a random symmetric matrix
    N = 1000
    M = 100
    Ham = torch.randn(N, N, dtype=dtype, device=device)
    Ham = (Ham + Ham.t())/2.
    print (Ham.shape)

    # full diagonalization
    t0 = time.time()
    w, _ = torch.symeig(Ham, eigenvectors=True)
    print (w[:10])
    print (w[-10:])
    print ('full diag', time.time()- t0)

    # lanczos
    Hopt = lambda x: torch.mv(Ham, x)
    t0 = time.time()
    phi0 = torch.randn(N, dtype=dtype, device=device)
    phi0 = phi0/phi0.norm()
    w = lanczos(Hopt, phi0, M)
    print (w[:10])
    print (w[-10:])
    print ('lanczos', time.time()- t0)
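
The timing claim in the commit message can be checked directly. The sketch below is not part of the commit; N, M, and the random test matrix are arbitrary illustrative choices, and it assumes the file above is saved as lanczos.py on a PyTorch version where torch.symeig is still available. It differentiates the dominant Ritz value returned by lanczos with respect to the matrix entries and times the backward pass against the forward pass.

# minimal sketch (not in the commit): time backward through lanczos vs. forward
import time
import torch
from lanczos import lanczos   # assumes the file above saved as lanczos.py

N, M = 1000, 100               # illustrative matrix size and Krylov dimension
Ham = torch.randn(N, N, dtype=torch.float64)
Ham = (Ham + Ham.t())/2.       # symmetrize
Ham.requires_grad_(True)       # so gradients of the Ritz values flow to Ham

phi0 = torch.randn(N, dtype=torch.float64)
phi0 = phi0/phi0.norm()        # lanczos expects a normalized start vector

t0 = time.time()
w = lanczos(lambda x: torch.mv(Ham, x), phi0, M)
print('forward ', time.time() - t0)

t0 = time.time()
w.abs().max().backward()       # gradient of the dominant Ritz value w.r.t. Ham
print('backward', time.time() - t0)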
 import torch
 torch.set_num_threads(4)
+from lanczos import lanczos

 def mpsrg(A, T):
@@ -10,8 +11,15 @@ def mpsrg(A, T):
     C = torch.einsum('ldr,idj->lirj', (Asymm, Asymm)).contiguous().view(D**2, D**2)
     w, _ = torch.symeig(B, eigenvectors=True)
+    #phi0 = Asymm.view(D**2*d)
+    #phi0 = phi0/phi0.norm()
+    #w = lanczos(lambda x: torch.mv(B,x), phi0, 100)
     lnZ1 = torch.log(w.abs().max())
     w, _ = torch.symeig(C, eigenvectors=True)
+    #phi0 = Asymm.sum(1).view(D**2)
+    #phi0 = phi0/phi0.norm()
+    #w = lanczos(lambda x: torch.mv(C,x), phi0, 100)
     lnZ2 = torch.log(w.abs().max())
     #lnZ1 = 0.0
@@ -44,13 +52,13 @@ def vmps(T, d, D, Nepochs=50, Ainit=None):
         optimizer.zero_grad()
         #print ('einsum', time.time()- t0)
         #print ((B-B.t()).abs().sum(), (C-C.t()).abs().sum())
-        #t0 = time.time()
+        t0 = time.time()
         loss = mpsrg(A, T.detach()) # loss = -lnZ , here we optimize over A
-        #print ('mpsrg', time.time()- t0)
-        #print (' loss', loss.item())
-        #t0 = time.time()
+        print ('mpsrg', time.time()- t0)
+        print (' loss', loss.item())
+        t0 = time.time()
         loss.backward(retain_graph=False)
-        #print ('backward', time.time()- t0)
+        print ('backward', time.time()- t0)
         return loss

     for epoch in range(Nepochs):
......
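
For reference, the commented-out lines added to mpsrg sketch the intended substitution: instead of fully diagonalizing the transfer matrix with torch.symeig, feed its matrix-vector product into lanczos and keep only the extreme eigenvalue. Below is a standalone illustration of that pattern, with a random symmetric stand-in for the transfer matrix B and purely illustrative dimensions D and d; it is not the repository's actual vmps setup.

# hedged sketch of the commented-out path in mpsrg: dominant eigenvalue via lanczos
import torch
from lanczos import lanczos

D, d = 20, 2                              # illustrative bond/physical dimensions
n = D**2 * d
B = torch.randn(n, n, dtype=torch.float64)
B = (B + B.t())/2.                        # random symmetric stand-in for B

phi0 = torch.randn(n, dtype=torch.float64)
phi0 = phi0/phi0.norm()                   # lanczos expects a normalized start vector

w = lanczos(lambda x: torch.mv(B, x), phi0, 100)
lnZ1 = torch.log(w.abs().max())           # same reduction as the symeig branch
print(lnZ1)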