Advanced Computing Platform for Theoretical Physics

commit大文件会使得服务器变得不稳定，请大家尽量只commit代码，不要commit大的文件。

Commit 8d8884fc by Lei Wang

### Wrap the contraction into a primitive to get (1) an efficient backward pass and (2) a correct double backward

parent 9c9dc1b8
 from .svd import SVD from .eigh import EigenSolver from .qr import QR
 ''' PyTorch has its own implementation of backward function for symmetric eigensolver https://github.com/pytorch/pytorch/blob/291746f11047361100102577ce7d1cfa1833be50/tools/autograd/templates/Functions.cpp#L1660 However, it assumes a triangular adjoint. We reimplement it to return a symmetric adjoint ''' import numpy as np import torch class EigenSolver(torch.autograd.Function): @staticmethod def forward(self, A): w, v = torch.symeig(A, eigenvectors=True) self.save_for_backward(w, v) return w, v @staticmethod def backward(self, dw, dv): w, v = self.saved_tensors dtype, device = w.dtype, w.device N = v.shape[0] F = w - w[:,None] # in case of degenerated eigenvalues, replace the following two lines with a safe inverse F.diagonal().fill_(np.inf); F = 1./F vt = v.t() vdv = vt@dv return v@(torch.diag(dw) + F*(vdv-vdv.t())/2) @vt def test_eigs(): M = 2 torch.manual_seed(42) A = torch.rand(M, M, dtype=torch.float64) A = torch.nn.Parameter(A+A.t()) assert(torch.autograd.gradcheck(DominantEigensolver.apply, A, eps=1e-6, atol=1e-4)) print("Test Pass!") if __name__=='__main__': test_eigs()
 ''' PyTorch has its own implementation of backward function for SVD https://github.com/pytorch/pytorch/blob/291746f11047361100102577ce7d1cfa1833be50/tools/autograd/templates/Functions.cpp#L1577 We reimplement it with a safe inverse function in light of degenerated singular values ''' import numpy as np import torch def safe_inverse(x, epsilon=1E-12): return x/(x**2 + epsilon) class SVD(torch.autograd.Function): @staticmethod def forward(self, A): U, S, V = torch.svd(A) self.save_for_backward(U, S, V) return U, S, V @staticmethod def backward(self, dU, dS, dV): U, S, V = self.saved_tensors Vt = V.t() Ut = U.t() M = U.size(0) N = V.size(0) NS = len(S) F = (S - S[:, None]) F = safe_inverse(F) F.diagonal().fill_(0) G = (S + S[:, None]) G.diagonal().fill_(np.inf) G = 1/G UdU = Ut @ dU VdV = Vt @ dV Su = (F+G)*(UdU-UdU.t())/2 Sv = (F-G)*(VdV-VdV.t())/2 dA = U @ (Su + Sv + torch.diag(dS)) @ Vt if (M>NS): dA = dA + (torch.eye(M, dtype=dU.dtype, device=dU.device) - U@Ut) @ (dU/S) @ Vt if (N>NS): dA = dA + (U/S) @ dV.t() @ (torch.eye(N, dtype=dU.dtype, device=dU.device) - V@Vt) return dA def test_svd(): M, N = 50, 40 torch.manual_seed(2) input = torch.rand(M, N, dtype=torch.float64, requires_grad=True) assert(torch.autograd.gradcheck(SVD.apply, input, eps=1e-6, atol=1e-4)) print("Test Pass!") if __name__=='__main__': test_svd()
ctmrg.py deleted 100644 → 0
 import torch from itertools import permutations from adlib import SVD svd = SVD.apply def build_tensor(K, H): c = torch.sqrt(torch.cosh(K)) s = torch.sqrt(torch.sinh(K)) Q = [c, s, c, -s] Q = torch.stack(Q).view(2, 2) M = [torch.exp(H), torch.exp(-H)] M = torch.stack(M).view(2) T = torch.einsum('a,ai,aj,ak,al->ijkl', (M, Q, Q, Q, Q)) return T def ctmrg(T, d, Dcut, max_iter): lnZ = 0.0 truncation_error = 0.0 #C = torch.randn(d, d, dtype=T.dtype, device=T.device) #T.sum((0,1)) #E = torch.randn(d, d, d, dtype=T.dtype, device=T.device)#T.sum(1) C = T.sum((0,1)) E = T.sum(1) D = d sold = torch.zeros(d, dtype=T.dtype, device=T.device) diff = 1E1 for n in range(max_iter): A = torch.einsum('ab,eca,bdg,cdfh->efgh', (C, E, E, T)).contiguous().view(D*d, D*d) A = (A+A.t())/2. D_new = min(D*d, Dcut) U, S, V = svd(A) s = S/S.max() truncation_error += S[D_new:].sum()/S.sum() P = U[:, :D_new] # projection operator #S, U = torch.symeig(A, eigenvectors=True) #truncation_error += 1.-S[-D_new:].sum()/S.sum() #P = U[:, -D_new:] # projection operator C = (P.t() @ A @ P) #(D, D) C = (C+C.t())/2. ET = torch.einsum('ldr,adbc->labrc', (E, T)).contiguous().view(D*d, d, D*d) #ET = torch.tensordot(E, T, dims=([1], [1])) #ET = ET.permute(0, 2, 3, 1, 4).contiguous().view(D*d, d, D*d) E = torch.einsum('li,ldr,rj->idj', (P, ET, P)) #(D_new, d, D_new) #E = ( P.t() @ ((ET.view(D*d*d, D*d)@P).view(D*d, d*D_new))).view(D_new,d,D_new) E = (E + E.permute(2, 1, 0))/2. D = D_new C = C/C.norm() E = E/E.norm() if (s.numel() == sold.numel()): diff = (s-sold).norm() if (diff < 1E-8): break sold = s #print ('ctmrg iterations', n) #C = C.detach() #E = E.detach() Z1 = torch.einsum('ab,bcd,fd,gha,hcij,fjk,lg,mil,mk', (C,E,C,E,T,E,C,E,C)) #CEC = torch.einsum('da,ebd,ce->abc', (C,E,C)).view(1, D**2*d) #ETE = torch.einsum('abc,lbdr,mdn->almcrn',(E,T,E)).contiguous().view(D**2*d, D**2*d) #ETE = (ETE+ETE.t())/2. 
#Z1 = CEC@ETE@CEC.t() Z3 = torch.einsum('ab,bc,cd,da', (C,C,C,C)) Z2 = torch.einsum('ab,bcd,de,fa,gcf,ge',(C,E,C,C,E,C)) #print (' Z1, Z2, Z3:', Z1.item(), Z2.item(), Z3.item()) lnZ += torch.log(Z1.abs()) + torch.log(Z3.abs()) - 2.*torch.log(Z2.abs()) return lnZ, truncation_error/n if __name__=='__main__': torch.set_num_threads(1) beta = 0.44 J = 1.0 h = 0.0 beta = torch.tensor(beta, dtype=torch.float64).requires_grad_() h = torch.tensor(h, dtype=torch.float64).requires_grad_() Dcut = 100 max_iter = 1000 T = build_tensor(beta*J, beta*h) lnZ, error = ctmrg(T, 2, Dcut, max_iter) print (lnZ.item(), error.item()) dlnZ, = torch.autograd.grad(lnZ, h, create_graph=True) dlnZ2, = torch.autograd.grad(dlnZ, h, create_graph=True) print (dlnZ.item(), dlnZ2.item())
 from .ctmrg import CTMRG from .singlelayer import projection, ctmrg, vumps from .utils import save_checkpoint, load_checkpoint, kronecker_product, symmetrize from .ctmrg import CTMRG, renormalize
 ... ... @@ -84,25 +84,4 @@ def CTMRG(T, chi, max_iter, epsilon, use_checkpoint=False): return C, E if __name__=='__main__': import time torch.manual_seed(42) D = 6 chi = 80 max_iter = 100 device = 'cpu' # T(u,l,d,r) T = torch.randn(D, D, D, D, dtype=torch.float64, device=device, requires_grad=True) T = (T + T.permute(0, 3, 2, 1))/2. # left-right symmetry T = (T + T.permute(2, 1, 0, 3))/2. # up-down symmetry T = (T + T.permute(3, 2, 1, 0))/2. # skew-diagonal symmetry T = (T + T.permute(1, 0, 3, 2))/2. # digonal symmetry T = T/T.norm() C, E = CTMRG(T, chi, max_iter, use_checkpoint=True) C, E = CTMRG(T, chi, max_iter, use_checkpoint=False) print( 'diffC = ', torch.dist( C, C.t() ) ) print( 'diffE = ', torch.dist( E, E.permute(2,1,0) ) )
trg.py deleted 100644 → 0
 import torch from adlib import SVD svd = SVD.apply def levin_nave_trg(T, D, Dcut, no_iter): lnZ = 0.0 truncation_error = 0.0 for n in range(no_iter): #print(n, " ", T.max(), " ", T.min()) maxval = T.max() T = T/maxval lnZ += 2**(no_iter-n)*torch.log(maxval) #print (2**(no_iter-n)*torch.log(maxval)/2**(no_iter)) Ma = T.permute(2, 1, 0, 3).contiguous().view(D**2, D**2) Mb = T.permute(3, 2, 1, 0).contiguous().view(D**2, D**2)