Advanced Computing Platform for Theoretical Physics


Commit 2b604e71 authored by Lei Wang

added adlib (previously I was hacking the PyTorch source directly)

parent 2b3c0035
adlib/__init__.py:

from .svd import SVD
from .eigh import EigenSolver
from .qr import QR
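
These wrappers are meant as drop-in replacements for PyTorch's built-in routines; downstream code calls them through their .apply method, as the ctmrg and trg diffs below show. A minimal usage sketch (illustrative, not part of the commit):

import torch
from adlib import SVD, EigenSolver, QR

svd = SVD.apply            # differentiable SVD with a safe inverse
eigh = EigenSolver.apply   # symmetric eigensolver with a symmetric adjoint
qr = QR.apply

A = torch.rand(5, 5, dtype=torch.float64, requires_grad=True)
U, S, V = svd(A)
S.sum().backward()         # gradients flow through the custom backward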
adlib/eigh.py:

'''
PyTorch has its own implementation of the backward function for the symmetric
eigensolver. However, it assumes a triangular adjoint; we reimplement it to
return a symmetric adjoint.
'''
import numpy as np
import torch

class EigenSolver(torch.autograd.Function):
    @staticmethod
    def forward(self, A):
        w, v = torch.symeig(A, eigenvectors=True)
        self.save_for_backward(w, v)
        return w, v

    @staticmethod
    def backward(self, dw, dv):
        w, v = self.saved_tensors
        dtype, device = w.dtype, w.device
        N = v.shape[0]

        # F_ij = 1/(w_j - w_i); masking the diagonal with inf makes the
        # i == j terms vanish after inversion
        F = w - w[:, None]
        # in case of degenerate eigenvalues, replace the following two lines
        # with a safe inverse
        F.diagonal().fill_(np.inf)
        F = 1. / F

        vt = v.t()
        vdv = vt @ dv
        # symmetric adjoint: dA = v (diag(dw) + F * (v^T dv - (v^T dv)^T)/2) v^T
        return v @ (torch.diag(dw) + F * (vdv - vdv.t()) / 2) @ vt
def test_eigs():
    M = 2
    A = torch.rand(M, M, dtype=torch.float64)
    A = torch.nn.Parameter(A + A.t())
    assert(torch.autograd.gradcheck(EigenSolver.apply, A, eps=1e-6, atol=1e-4))
    print("Test Pass!")

if __name__ == '__main__':
    test_eigs()
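
As a quick check of the symmetric adjoint (an illustrative sketch, not part of the commit): for a symmetric matrix, the gradient of a single eigenvalue w_k with respect to A is the outer product v_k v_k^T of its eigenvector, which the formula above reproduces directly.

import torch

A = torch.rand(3, 3, dtype=torch.float64)
A = torch.nn.Parameter(A + A.t())
w, v = EigenSolver.apply(A)
w[0].backward()   # dw_0/dA = v_0 v_0^T for a symmetric matrix
assert torch.allclose(A.grad, v[:, 0][:, None] * v[:, 0][None, :])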
Fixed-point solver module:

import torch
from torch.utils.checkpoint import detach_variable

def step(A, x):
    # one power-iteration step; fixing the overall sign makes the
    # iteration converge to a unique fixed point
    y = A @ x
    y = y[0].sign() * y
    return y / y.norm()
class FixedPoint(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, x0, tol):
        x, x_prev = step(A, x0), x0
        while torch.dist(x, x_prev) > tol:
            x, x_prev = step(A, x), x
        ctx.save_for_backward(A, x)
        ctx.tol = tol
        return x

    @staticmethod
    def backward(ctx, grad):
        A, x = detach_variable(ctx.saved_tensors)
        dA = grad
        # sum the geometric series grad + J^T grad + (J^T)^2 grad + ...
        # of vector-Jacobian products of step until it converges
        while True:
            with torch.enable_grad():
                grad = torch.autograd.grad(step(A, x), x, grad_outputs=grad)[0]
            if torch.norm(grad) > ctx.tol:
                dA = dA + grad
            else:
                break
        with torch.enable_grad():
            dA = torch.autograd.grad(step(A, x), A, grad_outputs=dA)[0]
        return dA, None, None
def test_backward():
    N = 4
    A = torch.rand(N, N, dtype=torch.float64, requires_grad=True)
    x0 = torch.rand(N, dtype=torch.float64)
    x0 = x0/x0.norm()
    tol = 1E-10
    input = A, x0, tol
    assert(torch.autograd.gradcheck(FixedPoint.apply, input, eps=1E-6, atol=tol))
    print("Backward Test Pass!")

def test_forward():
    N = 100
    tol = 1E-8
    dtype = torch.float64
    A = torch.randn(N, N, dtype=dtype)
    A = A + A.t()
    w, v = torch.symeig(A, eigenvectors=True)
    idx = torch.argmax(w.abs())
    v_exact = v[:, idx]
    v_exact = v_exact[0].sign() * v_exact
    x0 = torch.rand(N, dtype=dtype)
    x0 = x0/x0.norm()
    x = FixedPoint.apply(A, x0, tol)
    assert(torch.allclose(v_exact, x, rtol=tol, atol=tol))
    print("Forward Test Pass!")
if __name__ == '__main__':
    test_forward()
    test_backward()
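
A minimal usage sketch (illustrative, not part of the commit): any scalar function of the dominant eigenvector can be differentiated through FixedPoint.apply; the backward pass sums vector-Jacobian products of step rather than backpropagating through the unrolled iteration.

import torch

N = 4
A = torch.rand(N, N, dtype=torch.float64)
A = (A + A.t()).requires_grad_()
x0 = torch.rand(N, dtype=torch.float64)
x0 = x0 / x0.norm()

x = FixedPoint.apply(A, x0, 1E-10)   # dominant eigenvector of A
loss = x[0]                          # any scalar function of the fixed point
loss.backward()                      # fills A.grad by implicit differentiation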
adlib/qr.py:

import torch

class QR(torch.autograd.Function):
    @staticmethod
    def forward(self, A):
        Q, R = torch.qr(A)
        self.save_for_backward(A, Q, R)
        return Q, R
    @staticmethod
    def backward(self, dq, dr):
        A, q, r = self.saved_tensors
        if r.shape[0] == r.shape[1]:
            return _simple_qr_backward(q, r, dq, dr)
        # wide case (more columns than rows): split R = [U | D] with U
        # square, reduce to the square case, and stitch the blocks back
        M, N = r.shape
        B = A[:, M:]
        dU = dr[:, :M]
        dD = dr[:, M:]
        U = r[:, :M]
        da = _simple_qr_backward(q, U, dq + B @ dD.t(), dU)
        db = q @ dD
        return torch.cat([da, db], 1)
def _simple_qr_backward(q, r, dq, dr):
    if r.shape[-2] != r.shape[-1]:
        raise NotImplementedError("QrGrad not implemented when ncols > nrows "
                                  "or full_matrices is true and ncols != nrows.")

    qdq = q.t() @ dq
    qdq_ = qdq - qdq.t()
    rdr = r @ dr.t()
    rdr_ = rdr - rdr.t()
    tril = torch.tril(qdq_ + rdr_)

    def _TriangularSolve(x, r):
        """Equiv to x @ torch.inverse(r).t() if r is upper-tri."""
        res = torch.trtrs(x.t(), r, upper=True, transpose=False)[0].t()
        return res

    grad_a = q @ (dr + _TriangularSolve(tril, r))
    grad_b = _TriangularSolve(dq - q @ qdq, r)
    return grad_a + grad_b
def test_qr():
    M, N = 4, 6
    A = torch.randn(M, N, dtype=torch.float64, requires_grad=True)
    assert(torch.autograd.gradcheck(QR.apply, A, eps=1e-4, atol=1e-2))
    print("Test Pass!")

if __name__ == "__main__":
    test_qr()
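
For illustration (a sketch, not part of the commit): a scalar built from the R factor, such as the log of its absolute diagonal, differentiates cleanly through QR.apply.

import torch

A = torch.randn(4, 6, dtype=torch.float64, requires_grad=True)
Q, R = QR.apply(A)
loss = R.diag().abs().log().sum()   # depends on R only; dQ is materialized as zeros
loss.backward()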
adlib/svd.py:

'''
PyTorch has its own implementation of the backward function for SVD.
We reimplement it with a safe inverse function in light of degenerate
singular values.
'''
import numpy as np
import torch

def safe_inverse(x, epsilon=1E-12):
    # regularized reciprocal: stays finite as x -> 0
    return x / (x**2 + epsilon)

class SVD(torch.autograd.Function):
    @staticmethod
    def forward(self, A):
        U, S, V = torch.svd(A)
        self.save_for_backward(U, S, V)
        return U, S, V

    @staticmethod
    def backward(self, dU, dS, dV):
        U, S, V = self.saved_tensors
        Vt = V.t()
        Ut = U.t()
        M = U.size(0)
        N = V.size(0)
        NS = len(S)

        # F_ij = 1/(S_j - S_i) with the safe inverse guarding degeneracies;
        # G_ij = 1/(S_j + S_i) with the diagonal masked out
        F = (S - S[:, None])
        F = safe_inverse(F)
        F.diagonal().fill_(0)

        G = (S + S[:, None])
        G.diagonal().fill_(np.inf)
        G = 1/G

        UdU = Ut @ dU
        VdV = Vt @ dV

        Su = (F+G)*(UdU-UdU.t())/2
        Sv = (F-G)*(VdV-VdV.t())/2

        dA = U @ (Su + Sv + torch.diag(dS)) @ Vt
        # contributions from the subspaces not captured by U and V
        if (M > NS):
            dA = dA + (torch.eye(M, dtype=dU.dtype, device=dU.device) - U@Ut) @ (dU/S) @ Vt
        if (N > NS):
            dA = dA + (U/S) @ dV.t() @ (torch.eye(N, dtype=dU.dtype, device=dU.device) - V@Vt)
        return dA
def test_svd():
    M, N = 50, 40
    input = torch.rand(M, N, dtype=torch.float64, requires_grad=True)
    assert(torch.autograd.gradcheck(SVD.apply, input, eps=1e-6, atol=1e-4))
    print("Test Pass!")

if __name__ == '__main__':
    test_svd()
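
To see what safe_inverse buys (an illustrative sketch, not part of the commit): with exactly degenerate singular values, a naive 1/(S_i - S_j) is infinite and would poison the gradient with NaNs, while the regularized inverse keeps it finite.

import torch

A = torch.eye(3, dtype=torch.float64, requires_grad=True)   # all singular values equal
U, S, V = SVD.apply(A)
S.sum().backward()   # a naive 1/(S_i - S_j) would give 0 * inf = NaN here
assert not torch.isnan(A.grad).any()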
ctmrg.py:

 import torch
 from itertools import permutations
+from adlib import SVD
+svd = SVD.apply

 def ctmrg(T, d, Dcut, max_iter):
@@ -24,7 +26,7 @@ def ctmrg(T, d, Dcut, max_iter):
     A = (A+A.t())/2.
     D_new = min(D*d, Dcut)
-    U, S, V = torch.svd(A)
+    U, S, V = svd(A)
     s = S/S.max()
     truncation_error += S[D_new:].sum()/S.sum()
     P = U[:, :D_new] # projection operator
Main driver script:

@@ -8,7 +8,7 @@ import torch
-from trg import levin_nave_trg as contraction
+#from trg import levin_nave_trg as contraction
+from ctmrg import ctmrg as contraction
 #from vmps import vmps as contraction
@@ -17,8 +17,8 @@ if __name__=='__main__':
     import argparse
     parser = argparse.ArgumentParser(description='')
     parser.add_argument("-D", type=int, default=2, help="D")
-    parser.add_argument("-Dcut", type=int, default=20, help="Dcut")
-    parser.add_argument("-Niter", type=int, default=10, help="Niter")
+    parser.add_argument("-Dcut", type=int, default=30, help="Dcut")
+    parser.add_argument("-Niter", type=int, default=20, help="Niter")
     parser.add_argument("-float32", action='store_true', help="use float32")
     parser.add_argument("-lanczos_steps", type=int, default=0, help="lanczos steps")
@@ -44,8 +44,8 @@ if __name__=='__main__':
     A = torch.nn.Parameter(B.view(d, D**4))
     #boundary MPS
-    A1 = torch.nn.Parameter(0.01*torch.randn(Dcut, D**2*d, Dcut, dtype=dtype, device=device))
-    A2 = torch.nn.Parameter(0.01*torch.randn(Dcut, D**2, Dcut, dtype=dtype, device=device))
+    #A1 = torch.nn.Parameter(0.01*torch.randn(Dcut, D**2*d, Dcut, dtype=dtype, device=device))
+    #A2 = torch.nn.Parameter(0.01*torch.randn(Dcut, D**2, Dcut, dtype=dtype, device=device))
     #dimer covering
     T = torch.zeros(d, d, d, d, d, d, dtype=dtype, device=device)
trg.py:

 import torch
+from adlib import SVD
+svd = SVD.apply

 def levin_nave_trg(T, D, Dcut, no_iter):
@@ -15,8 +17,8 @@ def levin_nave_trg(T, D, Dcut, no_iter):
     Ma = T.permute(2, 1, 0, 3).contiguous().view(D**2, D**2)
     Mb = T.permute(3, 2, 1, 0).contiguous().view(D**2, D**2)
-    Ua, Sa, Va = torch.svd(Ma)
-    Ub, Sb, Vb = torch.svd(Mb)
+    Ua, Sa, Va = svd(Ma)
+    Ub, Sb, Vb = svd(Mb)
     D_new = min(D**2, Dcut)
     truncation_error += Sa[D_new:].sum()/Sa.sum() + Sb[D_new:].sum()/Sb.sum()