Advanced Computing Platform for Theoretical Physics

commit大文件会使得服务器变得不稳定,请大家尽量只commit代码,不要commit大的文件。

Commit 2b604e71 authored by Lei Wang's avatar Lei Wang
Browse files

added adlib (before I was hacking pytorch source)

parent 2b3c0035
from .svd import SVD
from .eigh import EigenSolver
from .qr import QR
'''
PyTorch has its own implementation of backward function for symmetric eigensolver https://github.com/pytorch/pytorch/blob/291746f11047361100102577ce7d1cfa1833be50/tools/autograd/templates/Functions.cpp#L1660
However, it assumes a triangular adjoint.
We reimplement it to return a symmetric adjoint
'''
import numpy as np
import torch
class EigenSolver(torch.autograd.Function):
@staticmethod
def forward(self, A):
w, v = torch.symeig(A, eigenvectors=True)
self.save_for_backward(w, v)
return w, v
@staticmethod
def backward(self, dw, dv):
w, v = self.saved_tensors
dtype, device = w.dtype, w.device
N = v.shape[0]
F = w - w[:,None]
# in case of degenerated eigenvalues, replace the following two lines with a safe inverse
F.diagonal().fill_(np.inf);
F = 1./F
vt = v.t()
vdv = vt@dv
return v@(torch.diag(dw) + F*(vdv-vdv.t())/2) @vt
def test_eigs():
M = 2
torch.manual_seed(42)
A = torch.rand(M, M, dtype=torch.float64)
A = torch.nn.Parameter(A+A.t())
assert(torch.autograd.gradcheck(DominantEigensolver.apply, A, eps=1e-6, atol=1e-4))
print("Test Pass!")
if __name__=='__main__':
test_eigs()
import torch
from torch.utils.checkpoint import detach_variable
def step(A, x):
y = A@x
y = y[0].sign() * y
return y/y.norm()
class FixedPoint(torch.autograd.Function):
@staticmethod
def forward(ctx, A, x0, tol):
x, x_prev = step(A, x0), x0
while torch.dist(x, x_prev) > tol:
x, x_prev = step(A, x), x
ctx.save_for_backward(A, x)
ctx.tol = tol
return x
@staticmethod
def backward(ctx, grad):
A, x = detach_variable(ctx.saved_tensors)
dA = grad
while True:
with torch.enable_grad():
grad = torch.autograd.grad(step(A, x), x, grad_outputs=grad)[0]
if (torch.norm(grad) > ctx.tol):
dA = dA + grad
else:
break
with torch.enable_grad():
dA = torch.autograd.grad(step(A, x), A, grad_outputs=dA)[0]
return dA, None, None
def test_backward():
N = 4
torch.manual_seed(2)
A = torch.rand(N, N, dtype=torch.float64, requires_grad=True)
x0 = torch.rand(N, dtype=torch.float64)
x0 = x0/x0.norm()
tol = 1E-10
input = A, x0, tol
assert(torch.autograd.gradcheck(FixedPoint.apply, input, eps=1E-6, atol=tol))
print("Backward Test Pass!")
def test_forward():
torch.manual_seed(42)
N = 100
tol = 1E-8
dtype = torch.float64
A = torch.randn(N, N, dtype=dtype)
A = A+A.t()
w, v = torch.symeig(A, eigenvectors=True)
idx = torch.argmax(w.abs())
v_exact = v[:, idx]
v_exact = v_exact[0].sign() * v_exact
x0 = torch.rand(N, dtype=dtype)
x0 = x0/x0.norm()
x = FixedPoint.apply(A, x0, tol)
assert(torch.allclose(v_exact, x, rtol=tol, atol=tol))
print("Forward Test Pass!")
if __name__=='__main__':
test_forward()
test_backward()
import torch
class QR(torch.autograd.Function):
@staticmethod
def forward(self, A):
Q, R = torch.qr(A)
self.save_for_backward(A, Q, R)
return Q, R
@staticmethod
def backward(self, dq, dr):
A, q, r = self.saved_tensors
if r.shape[0] == r.shape[1]:
return _simple_qr_backward(q, r, dq ,dr)
M, N = r.shape
B = A[:,M:]
dU = dr[:,:M]
dD = dr[:,M:]
U = r[:,:M]
da = _simple_qr_backward(q, U, dq+B@dD.t(), dU)
db = q@dD
return torch.cat([da, db], 1)
def _simple_qr_backward(q, r, dq, dr):
if r.shape[-2] != r.shape[-1]:
raise NotImplementedError("QrGrad not implemented when ncols > nrows "
"or full_matrices is true and ncols != nrows.")
qdq = q.t() @ dq
qdq_ = qdq - qdq.t()
rdr = r @ dr.t()
rdr_ = rdr - rdr.t()
tril = torch.tril(qdq_ + rdr_)
def _TriangularSolve(x, r):
"""Equiv to x @ torch.inverse(r).t() if r is upper-tri."""
res = torch.trtrs(x.t(), r, upper=True, transpose=False)[0].t()
return res
grad_a = q @ (dr + _TriangularSolve(tril, r))
grad_b = _TriangularSolve(dq - q @ qdq, r)
return grad_a + grad_b
def test_qr():
M, N = 4, 6
torch.manual_seed(2)
A = torch.randn(M, N)
A.requires_grad=True
assert(torch.autograd.gradcheck(QR.apply, A, eps=1e-4, atol=1e-2))
print("Test Pass!")
if __name__ == "__main__":
test_qr()
'''
PyTorch has its own implementation of backward function for SVD https://github.com/pytorch/pytorch/blob/291746f11047361100102577ce7d1cfa1833be50/tools/autograd/templates/Functions.cpp#L1577
We reimplement it with a safe inverse function in light of degenerated singular values
'''
import numpy as np
import torch
def safe_inverse(x, epsilon=1E-12):
return x/(x**2 + epsilon)
class SVD(torch.autograd.Function):
@staticmethod
def forward(self, A):
U, S, V = torch.svd(A)
self.save_for_backward(U, S, V)
return U, S, V
@staticmethod
def backward(self, dU, dS, dV):
U, S, V = self.saved_tensors
Vt = V.t()
Ut = U.t()
M = U.size(0)
N = V.size(0)
NS = len(S)
F = (S - S[:, None])
F = safe_inverse(F)
F.diagonal().fill_(0)
G = (S + S[:, None])
G.diagonal().fill_(np.inf)
G = 1/G
UdU = Ut @ dU
VdV = Vt @ dV
Su = (F+G)*(UdU-UdU.t())/2
Sv = (F-G)*(VdV-VdV.t())/2
dA = U @ (Su + Sv + torch.diag(dS)) @ Vt
if (M>NS):
dA = dA + (torch.eye(M, dtype=dU.dtype, device=dU.device) - U@Ut) @ (dU/S) @ Vt
if (N>NS):
dA = dA + (U/S) @ dV.t() @ (torch.eye(N, dtype=dU.dtype, device=dU.device) - V@Vt)
return dA
def test_svd():
M, N = 50, 40
torch.manual_seed(2)
input = torch.rand(M, N, dtype=torch.float64, requires_grad=True)
assert(torch.autograd.gradcheck(SVD.apply, input, eps=1e-6, atol=1e-4))
print("Test Pass!")
if __name__=='__main__':
test_svd()
import torch import torch
from itertools import permutations from itertools import permutations
from adlib import SVD
svd = SVD.apply
def ctmrg(T, d, Dcut, max_iter): def ctmrg(T, d, Dcut, max_iter):
...@@ -24,7 +26,7 @@ def ctmrg(T, d, Dcut, max_iter): ...@@ -24,7 +26,7 @@ def ctmrg(T, d, Dcut, max_iter):
A = (A+A.t())/2. A = (A+A.t())/2.
D_new = min(D*d, Dcut) D_new = min(D*d, Dcut)
U, S, V = torch.svd(A) U, S, V = svd(A)
s = S/S.max() s = S/S.max()
truncation_error += S[D_new:].sum()/S.sum() truncation_error += S[D_new:].sum()/S.sum()
P = U[:, :D_new] # projection operator P = U[:, :D_new] # projection operator
......
...@@ -8,7 +8,7 @@ import torch ...@@ -8,7 +8,7 @@ import torch
torch.set_num_threads(1) torch.set_num_threads(1)
torch.manual_seed(42) torch.manual_seed(42)
from trg import levin_nave_trg as contraction #from trg import levin_nave_trg as contraction
from ctmrg import ctmrg as contraction from ctmrg import ctmrg as contraction
#from vmps import vmps as contraction #from vmps import vmps as contraction
...@@ -17,8 +17,8 @@ if __name__=='__main__': ...@@ -17,8 +17,8 @@ if __name__=='__main__':
import argparse import argparse
parser = argparse.ArgumentParser(description='') parser = argparse.ArgumentParser(description='')
parser.add_argument("-D", type=int, default=2, help="D") parser.add_argument("-D", type=int, default=2, help="D")
parser.add_argument("-Dcut", type=int, default=20, help="Dcut") parser.add_argument("-Dcut", type=int, default=30, help="Dcut")
parser.add_argument("-Niter", type=int, default=10, help="Niter") parser.add_argument("-Niter", type=int, default=20, help="Niter")
parser.add_argument("-float32", action='store_true', help="use float32") parser.add_argument("-float32", action='store_true', help="use float32")
parser.add_argument("-lanczos_steps", type=int, default=0, help="lanczos steps") parser.add_argument("-lanczos_steps", type=int, default=0, help="lanczos steps")
...@@ -44,8 +44,8 @@ if __name__=='__main__': ...@@ -44,8 +44,8 @@ if __name__=='__main__':
A = torch.nn.Parameter(B.view(d, D**4)) A = torch.nn.Parameter(B.view(d, D**4))
#boundary MPS #boundary MPS
A1 = torch.nn.Parameter(0.01*torch.randn(Dcut, D**2*d, Dcut, dtype=dtype, device=device)) #A1 = torch.nn.Parameter(0.01*torch.randn(Dcut, D**2*d, Dcut, dtype=dtype, device=device))
A2 = torch.nn.Parameter(0.01*torch.randn(Dcut, D**2, Dcut, dtype=dtype, device=device)) #A2 = torch.nn.Parameter(0.01*torch.randn(Dcut, D**2, Dcut, dtype=dtype, device=device))
#dimer covering #dimer covering
T = torch.zeros(d, d, d, d, d, d, dtype=dtype, device=device) T = torch.zeros(d, d, d, d, d, d, dtype=dtype, device=device)
......
import torch import torch
from adlib import SVD
svd = SVD.apply
def levin_nave_trg(T, D, Dcut, no_iter): def levin_nave_trg(T, D, Dcut, no_iter):
...@@ -15,8 +17,8 @@ def levin_nave_trg(T, D, Dcut, no_iter): ...@@ -15,8 +17,8 @@ def levin_nave_trg(T, D, Dcut, no_iter):
Ma = T.permute(2, 1, 0, 3).contiguous().view(D**2, D**2) Ma = T.permute(2, 1, 0, 3).contiguous().view(D**2, D**2)
Mb = T.permute(3, 2, 1, 0).contiguous().view(D**2, D**2) Mb = T.permute(3, 2, 1, 0).contiguous().view(D**2, D**2)
Ua, Sa, Va = torch.svd(Ma) Ua, Sa, Va = svd(Ma)
Ub, Sb, Vb = torch.svd(Mb) Ub, Sb, Vb = svd(Mb)
D_new = min(D**2, Dcut) D_new = min(D**2, Dcut)
truncation_error += Sa[D_new:].sum()/Sa.sum() + Sb[D_new:].sum()/Sb.sum() truncation_error += Sa[D_new:].sum()/Sa.sum() + Sb[D_new:].sum()/Sb.sum()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment