Advanced Computing Platform for Theoretical Physics

Commit 8e0027d8 authored by Lei Wang's avatar Lei Wang
Browse files

split Ux Uy

parent cfd6ad61
......@@ -6,13 +6,14 @@ torch.set_num_threads(1)
from utils import trace
def renormalize(T, U):
chi, chi_new = U.shape[1], U.shape[2]
UTT = torch.einsum('abc,adeg,befh->cdfgh',(U, T, T)).contiguous().view(chi_new,chi,chi,chi**2)
def renormalize(T, Ux, Uy):
chi, chi_new = Ux.shape[1], Ux.shape[2]
UTT = torch.einsum('abc,adeg,befh->cdfgh',(Ux, T, T)).contiguous().view(chi_new,chi,chi,chi**2)
UTT2 = torch.einsum('abcd,gefd->abecfg', (UTT, UTT)).contiguous().view(chi_new, chi**2, chi**2, chi_new)
U = U.view(chi**2, chi_new)
Uy = Uy.view(chi**2, chi_new)
T = torch.einsum('abcd,be,cf->aefd', (UTT2, U, U)).contiguous()
T = torch.einsum('abcd,be,cf->aefd', (UTT2, Uy, Uy)).contiguous()
#T = torch.einsum('aem,bhn,fko,jlp,abcd,ecfg,gikl,dhij->mnop', (U, U, U, U, T, T, T, T))
return T
class Ising(torch.nn.Module):
......@@ -29,11 +30,13 @@ class Ising(torch.nn.Module):
chi = self.D
for step in range(niter):
chi_new = min(chi**2, self.chi)
self.params.append(torch.nn.Parameter(torch.randn(chi, chi, chi_new, dtype=dtype, device=device)))
for _ in range(2): #xy direction
U = torch.randn(chi, chi, chi_new, dtype=dtype, device=device)
self.params.append(torch.nn.Parameter(U))
chi = min(chi**2, self.chi)
def get_isometry(self, step):
return self.params[step]
return self.params[2*step], self.params[step*2+1]
def set_isometry(self, step, U):
self.params[step].data = U
......@@ -45,8 +48,8 @@ class Ising(torch.nn.Module):
lnZ += 4**(-n)*torch.log(f)
T = T / f
U = self.get_isometry(n) # isometry for this step
T = renormalize(T, U)
Ux, Uy = self.get_isometry(n) # isometry for this step
T = renormalize(T, Ux, Uy)
lnZ += torch.log(trace(T))/4.0**args.niter
return lnZ
......@@ -86,7 +89,7 @@ if __name__=='__main__':
p.requires_grad = False
sweeps = 0; step = 0; direction = 1
for _ in range(args.nsweeps*args.niter):
for _ in range(2*args.nsweeps*args.niter):
#sweep back and forth through isometries
turn_on_grad(step)
......@@ -115,7 +118,7 @@ if __name__=='__main__':
step += direction
#when reaching the end reverse sweep direction and report results
if (step==args.niter-1 or step==0):
if (step==2*args.niter-1 or step==0):
direction *= -1
sweeps += 1
with torch.no_grad():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment