Advanced Computing Platform for Theoretical Physics

Commit b202c35b authored by Lei Wang

move with_grad into forward

parent 8d8884fc
@@ -4,7 +4,7 @@ from tensornets import CTMRG
 '''
 customize the backward of ctmrg contraction
 we simply provide the adjoint of T so it does not need to diff into C and E
-however, since we have enable_grad in the backward, the double backward will enter C and E
+however, since we use C, E with enable_grad in the backward, the double backward will enter C and E
 '''
 class Contraction(torch.autograd.Function):
     @staticmethod
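The docstring above describes the pattern used throughout this file: a custom torch.autograd.Function whose backward supplies the adjoint of T by hand, so the first backward never has to differentiate through the CTMRG iteration itself. For reference, a minimal toy sketch of that custom-adjoint pattern (illustrative names only, not part of this commit), checked against finite differences with torch.autograd.gradcheck:

import torch

class SumOfSquares(torch.autograd.Function):
    """Toy analogue of Contraction: autograd never traces the forward;
    backward returns a hand-written adjoint instead."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return (x * x).sum()

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        return grad_out * 2.0 * x   # d(sum x^2)/dx = 2x, supplied by hand

x = torch.randn(5, dtype=torch.float64, requires_grad=True)
assert torch.autograd.gradcheck(SumOfSquares.apply, (x,))   # compares against finite differences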
@@ -15,21 +15,20 @@ class Contraction(torch.autograd.Function):
         self.epsilon = epsilon
         self.use_checkpoint = use_checkpoint
-        C, E = CTMRG(T, chi, maxiter, epsilon, use_checkpoint=use_checkpoint)
+        with torch.enable_grad():
+            C, E = CTMRG(T, chi, maxiter, epsilon, use_checkpoint=use_checkpoint)
+        self.save_for_backward(T, C, E)
 
         Z1 = torch.einsum('ab,bcd,fd,gha,chij,fjk,lg,mil,mk', (C,E,C,E,T,E,C,E,C))
         Z3 = torch.einsum('ab,bc,cd,da', (C,C,C,C))
         Z2 = torch.einsum('ab,bcd,de,fa,gcf,ge',(C,E,C,C,E,C))
-        self.save_for_backward(T)
 
         return torch.log(Z1.abs()) + torch.log(Z3.abs()) - 2.*torch.log(Z2.abs())
 
     @staticmethod
     def backward(self, dlnZ):
-        T, = self.saved_tensors
-        with torch.enable_grad():
-            C, E = CTMRG(T, self.chi, self.maxiter, self.epsilon, use_checkpoint=self.use_checkpoint)
+        T, C, E = self.saved_tensors
 
         up = torch.einsum('ab,bcd,fd,gha,fjk,lg,mil,mk->chij', (C,E,C,E,E,C,E,C)) * dlnZ
         dn = torch.einsum('ab,bcd,fd,gha,chij,fjk,lg,mil,mk', (C,E,C,E,T,E,C,E,C))
 
         return up/dn, None, None, None, None
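With this change, forward builds the converged environment tensors C, E under torch.enable_grad() and saves them together with T, so backward only re-contracts the saved network instead of rerunning CTMRG; because C and E still carry their autograd graph, a double backward can differentiate into them, as the docstring notes. The five values returned by backward line up with the five arguments of forward: a gradient for T, None for chi, maxiter, epsilon and use_checkpoint. Elsewhere in the file, contract is presumably bound to Contraction.apply (an assumption; that line is outside this diff), so a call would look like:

contract = Contraction.apply                               # assumed binding, not shown in this diff
lnZ = contract(T, chi, maxiter, epsilon, use_checkpoint)   # 0-dim tensor: log(Z1) + log(Z3) - 2 log(Z2)
lnZ.backward()                                             # grad of T is the hand-written adjoint up/dn (assuming T requires grad)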
@@ -55,6 +54,7 @@ if __name__=='__main__':
     maxiter = 50
     epsilon = 1E-10
     dtype = torch.float64
+    use_checkpoint = False
 
     #dimer covering
     T = torch.zeros(d, d, d, d, d, d, dtype=dtype)
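The new use_checkpoint flag is threaded through to CTMRG below. Presumably (an assumption; the tensornets internals are not shown here) it toggles torch.utils.checkpoint inside the CTMRG iteration, which recomputes a step during backward instead of storing its intermediates, trading compute for memory. A self-contained sketch of that switch on a stand-in step function:

import torch
from torch.utils.checkpoint import checkpoint

def step(T):                        # stand-in for one renormalization step, not the real CTMRG
    return torch.tanh(T @ T)

use_checkpoint = True
T = torch.randn(8, 8, dtype=torch.float64, requires_grad=True)
out = checkpoint(step, T) if use_checkpoint else step(T)
out.sum().backward()                # intermediates of step are recomputed here rather than stored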
@@ -72,13 +72,13 @@ if __name__=='__main__':
         A = torch.as_tensor(x).view(d, D, D, D, D)
         As = symmetrize(A).view(d, D**4)
         T2 = (As.t()@As).view(D, D, D, D, D, D, D, D).permute(0,4, 1,5, 2,6, 3,7).contiguous().view(D**2, D**2, D**2, D**2)
-        return -contract(T2, chi, maxiter, epsilon)
+        return -contract(T2, chi, maxiter, epsilon, use_checkpoint)
 
     def g(x):
         A = torch.as_tensor(x).view(d, D, D, D, D).requires_grad_()
         As = symmetrize(A).view(d, D**4)
         T2 = (As.t()@As).view(D, D, D, D, D, D, D, D).permute(0,4, 1,5, 2,6, 3,7).contiguous().view(D**2, D**2, D**2, D**2)
-        loss = -contract(T2, chi, maxiter, epsilon)
+        loss = -contract(T2, chi, maxiter, epsilon, use_checkpoint)
         loss.backward()
         return (A.grad).numpy().ravel()
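f and g above return the loss and its gradient as flat numpy-compatible values, the interface SciPy expects; one plausible consumer (an assumption, since the actual call site is outside the lines shown in this diff) is scipy.optimize.check_grad, which compares the hand-written CTMRG adjoint against finite differences. A sketch, reusing d, D, f and g from the script:

import numpy as np
from scipy.optimize import check_grad

x0 = np.random.randn(d * D**4)                      # flattened trial A tensor
print(check_grad(lambda x: float(f(x)), g, x0))     # f returns a 0-dim torch tensor, so cast to float;
                                                    # a small value means the custom backward is consistent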
@@ -79,8 +79,6 @@ if __name__=='__main__':
         info['loss'] = loss
         info['A'] = A
-        print ('A.grad in fun', A.grad)
 
         print (info['feval'], loss.item(), A.grad.norm().item())
         return loss.item(), A.grad.detach().cpu().numpy().ravel()
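fun packs the loss and the flattened gradient into a single return value and logs progress through info; that (loss, grad) convention matches the jac=True mode of scipy.optimize.minimize. The optimization driver omitted from this diff presumably looks something like the following (method and options are assumptions; if info is passed explicitly rather than captured as a closure, add args=(info,)):

import numpy as np
import scipy.optimize

x0 = np.random.randn(d * D**4)
res = scipy.optimize.minimize(fun, x0, jac=True, method='L-BFGS-B',
                              options={'maxiter': 100})
print(res.fun, res.nit)              # converged loss and number of iterations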