Advanced Computing Platform for Theoretical Physics

Commit 28563c09 authored by Lei Wang's avatar Lei Wang
Browse files

more careful about isometry shapes (this is useful for the first few isometries)

parent 555b2310
......@@ -20,22 +20,22 @@ class Ising(torch.nn.Module):
super(Ising, self).__init__()
self.D = 2
self.chi = chi
self.chi = chi # cutoff
self.niter = niter
self.dtype = dtype
self.device = device
self.params = torch.nn.ParameterList([torch.nn.Parameter(torch.randn(self.chi, self.chi, self.chi, dtype=dtype, device=device)) for _ in range(niter)])
self.params = torch.nn.ParameterList()
for step in range(niter):
chi = min(self.D**(2**(step//2)), self.chi)
chi_new = min(chi**2, self.chi)
self.params.append(torch.nn.Parameter(torch.randn(chi, chi, chi_new, dtype=dtype, device=device)))
def get_isometry(self, step):
chi = min(self.D**(2**(step//2)), self.chi)
chi_new = min(self.chi, chi**2)
return self.params[step][:chi, :chi, :chi_new]
return self.params[step]
def set_isometry(self, step, U):
chi = min(self.D**(2**(step//2)), self.chi)
chi_new = U.shape[-1]
self.params[step][:chi, :chi, :chi_new] = U
self.params[step].data = U
def init_isometry(self, T):
'''
......@@ -119,7 +119,7 @@ if __name__=='__main__':
for _ in range(args.nsweeps*args.niter):
#sweep back and forth through isometries
turn_on_grad(step)
for s in range(args.nstays): # local sweeps
model.zero_grad()
start = time.time()
......@@ -131,14 +131,23 @@ if __name__=='__main__':
#compute environment from gradient
#\partial tr(wE)/ \partial w = E^T
#cf http://www.matrixcalculus.org
E = model.params[step].grad
chi, chi, chi_new = E.shape
E = E.view(chi**2, chi_new)
E = model.params[step].grad
chi, chi_new = E.shape[0], E.shape[-1]
E = E.reshape(chi**2, chi_new)
#and perform mera update
U, S, V = torch.svd(E)
svd = time.time()
#check orthgonal condition of isometries
#isometry_last = model.params[step].clone()
#isometry_last = isometry_last.view(chi**2, chi_new)
model.params[step].data = (U@V.t()).view(chi, chi, chi_new)
#isometry = model.params[step].view(chi**2, chi_new)
#ovlp = isometry_last.t()@isometry
#I = torch.eye(chi_new, dtype=dtype, device=device)
#print (step, s, torch.dist(ovlp, I).item(), torch.dist(ovlp, -I).item())
#print(forward-start, backward - forward, svd-backward)
step += direction
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment