Advanced Computing Platform for Theoretical Physics

Commit 22dd9226 authored by Lei Wang

trying to speed up einsum in ctmrg; but the gain is minor, so the bottleneck is the svd

parent e28bad7f
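
The claim in the commit message can be checked with a standalone timing sketch (not part of the commit): build random tensors with the shapes used in ctmrg below, then time the corner contraction against the SVD of the (D*d) x (D*d) matrix it produces. The einsum string, the use of torch.svd, and Dcut = 180 come from the diff; the random inputs and the sizes D = 180, d = 2 are assumptions for illustration only.

import time
import torch

# Standalone timing sketch: corner contraction vs. the SVD it feeds.
# D mirrors Dcut = 180 from this commit's __main__ block; d = 2 and the
# random tensors are placeholders.
D, d = 180, 2
C = torch.randn(D, D)
E = torch.randn(D, d, D)
T = torch.randn(d, d, d, d)

t0 = time.time()
A = torch.einsum('ab,eca,bdg,cdfh->efgh', (C, E, E, T)).contiguous().view(D*d, D*d)
t_einsum = time.time() - t0

t0 = time.time()
U, S, V = torch.svd((A + A.t())/2.)
t_svd = time.time() - t0

print('einsum {:.3f}s  svd {:.3f}s'.format(t_einsum, t_svd))
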
@@ -16,6 +16,8 @@ def ctmrg(Tin, d, Dcut, no_iter):
     D = d
     for n in range(no_iter):
         A = torch.einsum('ab,eca,bdg,cdfh->efgh', (C, E, E, T)).contiguous().view(D*d, D*d)
+        A = (A+A.t())/2.
         D_new = min(D*d, Dcut)
         U, S, V = torch.svd(A)
         truncation_error += S[D_new:].sum()/S.sum()
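
Since the added line makes A exactly symmetric before the factorization, one possible way around the SVD bottleneck mentioned in the commit message (hypothetical, not what this commit does) is a symmetric eigendecomposition, which gives the same spectrum and the same truncated subspace up to signs and ordering. A minimal sketch, assuming torch.linalg.eigh is available (PyTorch >= 1.8) and using illustrative sizes:

import torch

# Sketch: for a symmetrized matrix, SVD and eigh agree up to signs/ordering.
# n and Dcut are illustrative; A is a random placeholder.
torch.manual_seed(0)
n, Dcut = 360, 180
A = torch.randn(n, n, dtype=torch.float64)
A = (A + A.t())/2.

U, S, V = torch.svd(A)            # singular values, descending
w, Q = torch.linalg.eigh(A)       # eigenvalues, ascending
idx = w.abs().argsort(descending=True)

# Same spectrum: singular values are the absolute eigenvalues.
print(torch.allclose(S, w[idx].abs()))

# Same rank-Dcut truncation (no degeneracy at the cut for random A).
A_svd = U[:, :Dcut] @ torch.diag(S[:Dcut]) @ V[:, :Dcut].t()
A_eig = Q[:, idx[:Dcut]] @ torch.diag(w[idx[:Dcut]]) @ Q[:, idx[:Dcut]].t()
print(torch.allclose(A_svd, A_eig, atol=1e-10))
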
@@ -25,7 +27,12 @@ def ctmrg(Tin, d, Dcut, no_iter):
         C = (C+C.t())/2.
         ET = torch.einsum('ldr,adbc->labrc', (E, T)).contiguous().view(D*d, d, D*d)
-        E = torch.einsum('li,ldr,rj->idj', (P, ET, P)) #(D, d, D)
+        #ET = torch.tensordot(E, T, dims=([1], [1]))
+        #ET = ET.permute(0, 2, 3, 1, 4).contiguous().view(D*d, d, D*d)
+        E = torch.einsum('li,ldr,rj->idj', (P, ET, P)) #(D_new, d, D_new)
+        #E = ( P.t() @ ((ET.view(D*d*d, D*d)@P).view(D*d, d*D_new))).view(D_new,d,D_new)
         E = (E + E.permute(2, 1, 0))/2.
         D = D_new
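
The commented-out lines in this hunk sketch replacing the einsum projection of the edge tensor by reshapes and a matrix product. A small standalone check (not part of the commit) that the two forms agree on random tensors, with illustrative sizes for D, d, and D_new:

import torch

# Check that the reshape/matmul form equals the einsum projection.
# ET plays the role of the (D*d, d, D*d) edge-times-T tensor, P the projector.
torch.manual_seed(0)
D, d, D_new = 20, 2, 16
ET = torch.randn(D*d, d, D*d, dtype=torch.float64)   # (l, d, r)
P = torch.randn(D*d, D_new, dtype=torch.float64)     # (l, i)

E_einsum = torch.einsum('li,ldr,rj->idj', (P, ET, P))
E_matmul = (P.t() @ ((ET.view(D*d*d, D*d) @ P).view(D*d, d*D_new))).view(D_new, d, D_new)

print(torch.allclose(E_einsum, E_matmul))   # expected: True
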
@@ -44,15 +51,16 @@ def ctmrg(Tin, d, Dcut, no_iter):
     Z3 = torch.einsum('ab,bc,cd,da', (C,C,C,C))
     Z2 = torch.einsum('ab,bcd,de,fa,gcf,ge',(C,E,C,C,E,C))
-    print (Z1.item(), Z2.item(), Z3.item())
+    print ('Z1, Z2, Z3:', Z1.item(), Z2.item(), Z3.item())
     lnZ += torch.log(Z1.abs()) + torch.log(Z3.abs()) - 2.*torch.log(Z2.abs())
     return lnZ, truncation_error

 if __name__=='__main__':
+    torch.set_num_threads(4)
     K = torch.tensor([0.44])
-    Dcut = 20
-    n = 20
+    Dcut = 180
+    n = 50
     #Boltzmann factor on a bond M=LR^T
     M = torch.stack([torch.cat([torch.exp(K), torch.exp(-K)]),
...
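
For reference, the lnZ line in the hunk above is the usual corner-transfer-matrix estimate of the partition function per site: from the einsum strings, Z3 traces four corner matrices C, Z2 additionally inserts two edge tensors E, and Z1 (defined just above this hunk, not shown in the diff) presumably contracts the full C/E/T network, so each call adds

\[
\ln Z_1 + \ln Z_3 - 2\ln Z_2 \;=\; \ln\frac{Z_1\, Z_3}{Z_2^{2}} .
\]
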
@@ -66,7 +66,7 @@ if __name__=='__main__':
         lnZ, error2 = contraction(T2, D**2, Dcut, Niter)
         loss = (-lnT + lnZ)
         print ('contraction done {:.3f}s'.format(time.time()-t0))
-        print ('residual entropy', -loss.item(), error1.item(), error2.item())
+        print ('loss, error', loss.item(), error1.item(), error2.item())
         t0=time.time()
         loss.backward()
@@ -75,4 +75,4 @@ if __name__=='__main__':
     for epoch in range(100):
         loss = optimizer.step(closure)
-        #print ('epoch, loss', epoch, loss)
+        print ('epoch, residual entropy', epoch, -loss.item())
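
The loop above relies on the standard torch.optim.LBFGS closure pattern; the script's actual closure (which builds the -lnT + lnZ loss) lies outside the hunks shown here. A generic, self-contained sketch of that pattern with a placeholder parameter and loss:

import torch

# Generic LBFGS closure sketch; x and the quadratic loss are placeholders.
x = torch.randn(2, dtype=torch.float64, requires_grad=True)
optimizer = torch.optim.LBFGS([x])

def closure():
    optimizer.zero_grad()
    loss = (x**2).sum()        # stands in for the script's -lnT + lnZ
    loss.backward()
    return loss

for epoch in range(100):
    loss = optimizer.step(closure)
    print('epoch, loss', epoch, loss.item())
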