Advanced Computing Platform for Theoretical Physics

commit大文件会使得服务器变得不稳定,请大家尽量只commit代码,不要commit大的文件。

Commit b202c35b authored by Lei Wang's avatar Lei Wang
Browse files

move with_grad into forward

parent 8d8884fc
......@@ -4,7 +4,7 @@ from tensornets import CTMRG
'''
customize the backward of ctmrg contraction
we simply provide the adjoint of T so it does not need to diff into C and E
however, since we have enable_grad in the backward, the double backward will enter C and E
however, since we use C, E with enable_grad in the backward, the double backward will enter C and E
'''
class Contraction(torch.autograd.Function):
@staticmethod
......@@ -15,21 +15,20 @@ class Contraction(torch.autograd.Function):
self.epsilon = epsilon
self.use_checkpoint = use_checkpoint
C, E = CTMRG(T, chi, maxiter, epsilon, use_checkpoint=use_checkpoint)
with torch.enable_grad():
C, E = CTMRG(T, chi, maxiter, epsilon, use_checkpoint=use_checkpoint)
self.save_for_backward(T, C, E)
Z1 = torch.einsum('ab,bcd,fd,gha,chij,fjk,lg,mil,mk', (C,E,C,E,T,E,C,E,C))
Z3 = torch.einsum('ab,bc,cd,da', (C,C,C,C))
Z2 = torch.einsum('ab,bcd,de,fa,gcf,ge',(C,E,C,C,E,C))
self.save_for_backward(T)
return torch.log(Z1.abs()) + torch.log(Z3.abs()) - 2.*torch.log(Z2.abs())
@staticmethod
def backward(self, dlnZ):
T, = self.saved_tensors
with torch.enable_grad():
C, E = CTMRG(T, self.chi, self.maxiter, self.epsilon, use_checkpoint=self.use_checkpoint)
T, C, E = self.saved_tensors
up = torch.einsum('ab,bcd,fd,gha,fjk,lg,mil,mk->chij', (C,E,C,E,E,C,E,C)) * dlnZ
dn = torch.einsum('ab,bcd,fd,gha,chij,fjk,lg,mil,mk', (C,E,C,E,T,E,C,E,C))
return up/dn, None, None, None, None
......@@ -55,6 +54,7 @@ if __name__=='__main__':
maxiter = 50
epsilon = 1E-10
dtype = torch.float64
use_checkpoint = False
#dimer covering
T = torch.zeros(d, d, d, d, d, d, dtype=dtype)
......@@ -72,13 +72,13 @@ if __name__=='__main__':
A = torch.as_tensor(x).view(d, D, D, D, D)
As = symmetrize(A).view(d, D**4)
T2 = (As.t()@As).view(D, D, D, D, D, D, D, D).permute(0,4, 1,5, 2,6, 3,7).contiguous().view(D**2, D**2, D**2, D**2)
return -contract(T2, chi, maxiter, epsilon)
return -contract(T2, chi, maxiter, epsilon, use_checkpoint)
def g(x):
A = torch.as_tensor(x).view(d, D, D, D, D).requires_grad_()
As = symmetrize(A).view(d, D**4)
T2 = (As.t()@As).view(D, D, D, D, D, D, D, D).permute(0,4, 1,5, 2,6, 3,7).contiguous().view(D**2, D**2, D**2, D**2)
loss = -contract(T2, chi, maxiter, epsilon)
loss = -contract(T2, chi, maxiter, epsilon, use_checkpoint)
loss.backward()
return (A.grad).numpy().ravel()
......
......@@ -79,8 +79,6 @@ if __name__=='__main__':
info['loss'] = loss
info['A'] = A
print ('A.grad in fun', A.grad)
print (info['feval'], loss.item(), A.grad.norm().item())
return loss.item(), A.grad.detach().cpu().numpy().ravel()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment