Advanced Computing Platform for Theoretical Physics

commit大文件会使得服务器变得不稳定，请大家尽量只commit代码，不要commit大的文件。

### added adlib (before I was hacking pytorch source)

parent 2b3c0035
adlib/__init__.py 0 → 100644
 from .svd import SVD from .eigh import EigenSolver from .qr import QR
adlib/eigh.py 0 → 100644
 ''' PyTorch has its own implementation of backward function for symmetric eigensolver https://github.com/pytorch/pytorch/blob/291746f11047361100102577ce7d1cfa1833be50/tools/autograd/templates/Functions.cpp#L1660 However, it assumes a triangular adjoint. We reimplement it to return a symmetric adjoint ''' import numpy as np import torch class EigenSolver(torch.autograd.Function): @staticmethod def forward(self, A): w, v = torch.symeig(A, eigenvectors=True) self.save_for_backward(w, v) return w, v @staticmethod def backward(self, dw, dv): w, v = self.saved_tensors dtype, device = w.dtype, w.device N = v.shape F = w - w[:,None] # in case of degenerated eigenvalues, replace the following two lines with a safe inverse F.diagonal().fill_(np.inf); F = 1./F vt = v.t() vdv = vt@dv return v@(torch.diag(dw) + F*(vdv-vdv.t())/2) @vt def test_eigs(): M = 2 torch.manual_seed(42) A = torch.rand(M, M, dtype=torch.float64) A = torch.nn.Parameter(A+A.t()) assert(torch.autograd.gradcheck(DominantEigensolver.apply, A, eps=1e-6, atol=1e-4)) print("Test Pass!") if __name__=='__main__': test_eigs()
adlib/power.py 0 → 100644
 import torch from torch.utils.checkpoint import detach_variable def step(A, x): y = A@x y = y.sign() * y return y/y.norm() class FixedPoint(torch.autograd.Function): @staticmethod def forward(ctx, A, x0, tol): x, x_prev = step(A, x0), x0 while torch.dist(x, x_prev) > tol: x, x_prev = step(A, x), x ctx.save_for_backward(A, x) ctx.tol = tol return x @staticmethod def backward(ctx, grad): A, x = detach_variable(ctx.saved_tensors) dA = grad while True: with torch.enable_grad(): grad = torch.autograd.grad(step(A, x), x, grad_outputs=grad) if (torch.norm(grad) > ctx.tol): dA = dA + grad else: break with torch.enable_grad(): dA = torch.autograd.grad(step(A, x), A, grad_outputs=dA) return dA, None, None def test_backward(): N = 4 torch.manual_seed(2) A = torch.rand(N, N, dtype=torch.float64, requires_grad=True) x0 = torch.rand(N, dtype=torch.float64) x0 = x0/x0.norm() tol = 1E-10 input = A, x0, tol assert(torch.autograd.gradcheck(FixedPoint.apply, input, eps=1E-6, atol=tol)) print("Backward Test Pass!") def test_forward(): torch.manual_seed(42) N = 100 tol = 1E-8 dtype = torch.float64 A = torch.randn(N, N, dtype=dtype) A = A+A.t() w, v = torch.symeig(A, eigenvectors=True) idx = torch.argmax(w.abs()) v_exact = v[:, idx] v_exact = v_exact.sign() * v_exact x0 = torch.rand(N, dtype=dtype) x0 = x0/x0.norm() x = FixedPoint.apply(A, x0, tol) assert(torch.allclose(v_exact, x, rtol=tol, atol=tol)) print("Forward Test Pass!") if __name__=='__main__': test_forward() test_backward()
adlib/qr.py 0 → 100644
 import torch class QR(torch.autograd.Function): @staticmethod def forward(self, A): Q, R = torch.qr(A) self.save_for_backward(A, Q, R) return Q, R @staticmethod def backward(self, dq, dr): A, q, r = self.saved_tensors if r.shape == r.shape: return _simple_qr_backward(q, r, dq ,dr) M, N = r.shape B = A[:,M:] dU = dr[:,:M] dD = dr[:,M:] U = r[:,:M] da = _simple_qr_backward(q, U, dq+B@dD.t(), dU) db = q@dD return torch.cat([da, db], 1) def _simple_qr_backward(q, r, dq, dr): if r.shape[-2] != r.shape[-1]: raise NotImplementedError("QrGrad not implemented when ncols > nrows " "or full_matrices is true and ncols != nrows.") qdq = q.t() @ dq qdq_ = qdq - qdq.t() rdr = r @ dr.t() rdr_ = rdr - rdr.t() tril = torch.tril(qdq_ + rdr_) def _TriangularSolve(x, r): """Equiv to x @ torch.inverse(r).t() if r is upper-tri.""" res = torch.trtrs(x.t(), r, upper=True, transpose=False).t() return res grad_a = q @ (dr + _TriangularSolve(tril, r)) grad_b = _TriangularSolve(dq - q @ qdq, r) return grad_a + grad_b def test_qr(): M, N = 4, 6 torch.manual_seed(2) A = torch.randn(M, N) A.requires_grad=True assert(torch.autograd.gradcheck(QR.apply, A, eps=1e-4, atol=1e-2)) print("Test Pass!") if __name__ == "__main__": test_qr()
adlib/svd.py 0 → 100644
 ''' PyTorch has its own implementation of backward function for SVD https://github.com/pytorch/pytorch/blob/291746f11047361100102577ce7d1cfa1833be50/tools/autograd/templates/Functions.cpp#L1577 We reimplement it with a safe inverse function in light of degenerated singular values ''' import numpy as np import torch def safe_inverse(x, epsilon=1E-12): return x/(x**2 + epsilon) class SVD(torch.autograd.Function): @staticmethod def forward(self, A): U, S, V = torch.svd(A) self.save_for_backward(U, S, V) return U, S, V @staticmethod def backward(self, dU, dS, dV): U, S, V = self.saved_tensors Vt = V.t() Ut = U.t() M = U.size(0) N = V.size(0) NS = len(S) F = (S - S[:, None]) F = safe_inverse(F) F.diagonal().fill_(0) G = (S + S[:, None]) G.diagonal().fill_(np.inf) G = 1/G UdU = Ut @ dU VdV = Vt @ dV Su = (F+G)*(UdU-UdU.t())/2 Sv = (F-G)*(VdV-VdV.t())/2 dA = U @ (Su + Sv + torch.diag(dS)) @ Vt if (M>NS): dA = dA + (torch.eye(M, dtype=dU.dtype, device=dU.device) - U@Ut) @ (dU/S) @ Vt if (N>NS): dA = dA + (U/S) @ dV.t() @ (torch.eye(N, dtype=dU.dtype, device=dU.device) - V@Vt) return dA def test_svd(): M, N = 50, 40 torch.manual_seed(2) input = torch.rand(M, N, dtype=torch.float64, requires_grad=True) assert(torch.autograd.gradcheck(SVD.apply, input, eps=1e-6, atol=1e-4)) print("Test Pass!") if __name__=='__main__': test_svd()
 import torch import torch from itertools import permutations from itertools import permutations from adlib import SVD svd = SVD.apply def ctmrg(T, d, Dcut, max_iter): def ctmrg(T, d, Dcut, max_iter): ... @@ -24,7 +26,7 @@ def ctmrg(T, d, Dcut, max_iter): ... @@ -24,7 +26,7 @@ def ctmrg(T, d, Dcut, max_iter): A = (A+A.t())/2. A = (A+A.t())/2. D_new = min(D*d, Dcut) D_new = min(D*d, Dcut) U, S, V = torch.svd(A) U, S, V = svd(A) s = S/S.max() s = S/S.max() truncation_error += S[D_new:].sum()/S.sum() truncation_error += S[D_new:].sum()/S.sum() P = U[:, :D_new] # projection operator P = U[:, :D_new] # projection operator ... ...
 ... @@ -8,7 +8,7 @@ import torch ... @@ -8,7 +8,7 @@ import torch torch.set_num_threads(1) torch.set_num_threads(1) torch.manual_seed(42) torch.manual_seed(42) from trg import levin_nave_trg as contraction #from trg import levin_nave_trg as contraction from ctmrg import ctmrg as contraction from ctmrg import ctmrg as contraction #from vmps import vmps as contraction #from vmps import vmps as contraction ... @@ -17,8 +17,8 @@ if __name__=='__main__': ... @@ -17,8 +17,8 @@ if __name__=='__main__': import argparse import argparse parser = argparse.ArgumentParser(description='') parser = argparse.ArgumentParser(description='') parser.add_argument("-D", type=int, default=2, help="D") parser.add_argument("-D", type=int, default=2, help="D") parser.add_argument("-Dcut", type=int, default=20, help="Dcut") parser.add_argument("-Dcut", type=int, default=30, help="Dcut") parser.add_argument("-Niter", type=int, default=10, help="Niter") parser.add_argument("-Niter", type=int, default=20, help="Niter") parser.add_argument("-float32", action='store_true', help="use float32") parser.add_argument("-float32", action='store_true', help="use float32") parser.add_argument("-lanczos_steps", type=int, default=0, help="lanczos steps") parser.add_argument("-lanczos_steps", type=int, default=0, help="lanczos steps") ... @@ -44,8 +44,8 @@ if __name__=='__main__': ... @@ -44,8 +44,8 @@ if __name__=='__main__': A = torch.nn.Parameter(B.view(d, D**4)) A = torch.nn.Parameter(B.view(d, D**4)) #boundary MPS #boundary MPS A1 = torch.nn.Parameter(0.01*torch.randn(Dcut, D**2*d, Dcut, dtype=dtype, device=device)) #A1 = torch.nn.Parameter(0.01*torch.randn(Dcut, D**2*d, Dcut, dtype=dtype, device=device)) A2 = torch.nn.Parameter(0.01*torch.randn(Dcut, D**2, Dcut, dtype=dtype, device=device)) #A2 = torch.nn.Parameter(0.01*torch.randn(Dcut, D**2, Dcut, dtype=dtype, device=device)) #dimer covering #dimer covering T = torch.zeros(d, d, d, d, d, d, dtype=dtype, device=device) T = torch.zeros(d, d, d, d, d, d, dtype=dtype, device=device) ... ...
 import torch import torch from adlib import SVD svd = SVD.apply def levin_nave_trg(T, D, Dcut, no_iter): def levin_nave_trg(T, D, Dcut, no_iter): ... @@ -15,8 +17,8 @@ def levin_nave_trg(T, D, Dcut, no_iter): ... @@ -15,8 +17,8 @@ def levin_nave_trg(T, D, Dcut, no_iter): Ma = T.permute(2, 1, 0, 3).contiguous().view(D**2, D**2) Ma = T.permute(2, 1, 0, 3).contiguous().view(D**2, D**2) Mb = T.permute(3, 2, 1, 0).contiguous().view(D**2, D**2) Mb = T.permute(3, 2, 1, 0).contiguous().view(D**2, D**2) Ua, Sa, Va = torch.svd(Ma) Ua, Sa, Va = svd(Ma) Ub, Sb, Vb = torch.svd(Mb) Ub, Sb, Vb = svd(Mb) D_new = min(D**2, Dcut) D_new = min(D**2, Dcut) truncation_error += Sa[D_new:].sum()/Sa.sum() + Sb[D_new:].sum()/Sb.sum() truncation_error += Sa[D_new:].sum()/Sa.sum() + Sb[D_new:].sum()/Sb.sum() ... ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment