Advanced Computing Platform for Theoretical Physics

Commit 2e8201e3 authored by Lei Wang

dhotrg

parent 24f0891f
from .svd import SVD
from .eigh import EigenSolver
from .qr import QR
'''
PyTorch has its own implementation of backward function for symmetric eigensolver https://github.com/pytorch/pytorch/blob/291746f11047361100102577ce7d1cfa1833be50/tools/autograd/templates/Functions.cpp#L1660
However, it assumes a triangular adjoint.
We reimplement it to return a symmetric adjoint.
'''
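# With A = V diag(w) V^T and upstream gradients (dw, dv), the backward below returns
#   dA = V @ (diag(dw) + F * (V^T dv - dv^T V)/2) @ V^T,   F_ij = 1/(w_j - w_i), F_ii = 0,
# which is symmetric because both F and (V^T dv - dv^T V)/2 are antisymmetric.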
import numpy as np
import torch
class EigenSolver(torch.autograd.Function):
    @staticmethod
    def forward(self, A):
        w, v = torch.symeig(A, eigenvectors=True)
        self.save_for_backward(w, v)
        return w, v

    @staticmethod
    def backward(self, dw, dv):
        w, v = self.saved_tensors
        dtype, device = w.dtype, w.device
        N = v.shape[0]

        F = w - w[:,None]
        # in case of degenerate eigenvalues, replace the following two lines with a safe inverse
        F.diagonal().fill_(np.inf)
        F = 1./F
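        # A possible drop-in for (nearly) degenerate spectra, assuming the safe_inverse
        # helper defined in the SVD module of this commit (a regularized 1/x = x/(x**2 + eps)):
        #   F = safe_inverse(w - w[:, None])
        #   F.diagonal().fill_(0)
        # This keeps the backward finite when eigenvalue gaps w[j] - w[i] vanish.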
        vt = v.t()
        vdv = vt@dv

        return v@(torch.diag(dw) + F*(vdv-vdv.t())/2) @vt
def test_eigs():
    M = 2
    torch.manual_seed(42)
    A = torch.rand(M, M, dtype=torch.float64)
    A = torch.nn.Parameter(A+A.t())
    assert(torch.autograd.gradcheck(EigenSolver.apply, A, eps=1e-6, atol=1e-4))
    print("Test Pass!")

if __name__=='__main__':
    test_eigs()
import torch
from torch.utils.checkpoint import detach_variable
def step(A, x):
    y = A@x
    y = y[0].sign() * y
    return y/y.norm()
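# step is one sign-fixed, normalized power-iteration update. FixedPoint below iterates it to a
# fixed point x* = step(A, x*) (the dominant eigenvector of A) and differentiates through the
# implicit relation instead of unrolling the loop: backward accumulates
#   dL/dA = (d step/d A)^T * sum_{k>=0} [(d step/d x)^T]^k * grad,
# summing the geometric series term by term until the running vector-Jacobian product drops
# below tol. For a non-degenerate dominant eigenvalue the series converges, since the Jacobian
# of step at x* has spectral radius |lambda_2/lambda_1| < 1.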
class FixedPoint(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, x0, tol):
        x, x_prev = step(A, x0), x0
        while torch.dist(x, x_prev) > tol:
            x, x_prev = step(A, x), x
        ctx.save_for_backward(A, x)
        ctx.tol = tol
        return x

    @staticmethod
    def backward(ctx, grad):
        A, x = detach_variable(ctx.saved_tensors)
        dA = grad
        while True:
            with torch.enable_grad():
                grad = torch.autograd.grad(step(A, x), x, grad_outputs=grad)[0]
            if (torch.norm(grad) > ctx.tol):
                dA = dA + grad
            else:
                break
        with torch.enable_grad():
            dA = torch.autograd.grad(step(A, x), A, grad_outputs=dA)[0]
        return dA, None, None
def test_backward():
    N = 4
    torch.manual_seed(2)
    A = torch.rand(N, N, dtype=torch.float64, requires_grad=True)
    x0 = torch.rand(N, dtype=torch.float64)
    x0 = x0/x0.norm()
    tol = 1E-10

    input = A, x0, tol
    assert(torch.autograd.gradcheck(FixedPoint.apply, input, eps=1E-6, atol=tol))
    print("Backward Test Pass!")

def test_forward():
    torch.manual_seed(42)
    N = 100
    tol = 1E-8
    dtype = torch.float64
    A = torch.randn(N, N, dtype=dtype)
    A = A+A.t()

    w, v = torch.symeig(A, eigenvectors=True)
    idx = torch.argmax(w.abs())
    v_exact = v[:, idx]
    v_exact = v_exact[0].sign() * v_exact

    x0 = torch.rand(N, dtype=dtype)
    x0 = x0/x0.norm()
    x = FixedPoint.apply(A, x0, tol)

    assert(torch.allclose(v_exact, x, rtol=tol, atol=tol))
    print("Forward Test Pass!")

if __name__=='__main__':
    test_forward()
    test_backward()
import torch
class QR(torch.autograd.Function):
    @staticmethod
    def forward(self, A):
        Q, R = torch.qr(A)
        self.save_for_backward(A, Q, R)
        return Q, R
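    # The backward distinguishes two cases. For square R it applies the standard QR adjoint
    # (_simple_qr_backward below). When R has more columns than rows, split A = [X | Y]
    # (X = first M columns) and R = [U | D] with X = Q U; the adjoint of Y is then Q @ dD,
    # and the adjoint of X follows from the square case with dq augmented by Y @ dD^T,
    # which accounts for Y's dependence on Q.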
    @staticmethod
    def backward(self, dq, dr):
        A, q, r = self.saved_tensors
        if r.shape[0] == r.shape[1]:
            return _simple_qr_backward(q, r, dq, dr)
        M, N = r.shape
        B = A[:, M:]
        dU = dr[:, :M]
        dD = dr[:, M:]
        U = r[:, :M]

        da = _simple_qr_backward(q, U, dq+B@dD.t(), dU)
        db = q@dD
        return torch.cat([da, db], 1)
def _simple_qr_backward(q, r, dq, dr):
    if r.shape[-2] != r.shape[-1]:
        raise NotImplementedError("QrGrad not implemented when ncols > nrows "
                                  "or full_matrices is true and ncols != nrows.")

    qdq = q.t() @ dq
    qdq_ = qdq - qdq.t()
    rdr = r @ dr.t()
    rdr_ = rdr - rdr.t()
    tril = torch.tril(qdq_ + rdr_)

    def _TriangularSolve(x, r):
        """Equiv to x @ torch.inverse(r).t() if r is upper-tri."""
        res = torch.trtrs(x.t(), r, upper=True, transpose=False)[0].t()
        return res

    grad_a = q @ (dr + _TriangularSolve(tril, r))
    grad_b = _TriangularSolve(dq - q @ qdq, r)
    return grad_a + grad_b
def test_qr():
    M, N = 4, 6
    torch.manual_seed(2)
    A = torch.randn(M, N)
    A.requires_grad = True
    assert(torch.autograd.gradcheck(QR.apply, A, eps=1e-4, atol=1e-2))
    print("Test Pass!")

if __name__ == "__main__":
    test_qr()
'''
PyTorch has its own implementation of backward function for SVD https://github.com/pytorch/pytorch/blob/291746f11047361100102577ce7d1cfa1833be50/tools/autograd/templates/Functions.cpp#L1577
We reimplement it with a safe inverse function to handle degenerate singular values.
'''
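# Conventions of the backward below, for A = U diag(S) V^T with upstream gradients (dU, dS, dV):
#   F_ij = 1/(S_j - S_i)  (off-diagonal, regularized by safe_inverse; F_ii = 0)
#   G_ij = 1/(S_i + S_j)  (off-diagonal; G_ii = 0)
#   dA = U @ ( (F+G)*(U^T dU - dU^T U)/2 + (F-G)*(V^T dV - dV^T V)/2 + diag(dS) ) @ V^T
# plus the projector terms (I - U U^T) @ (dU/S) @ V^T and (U/S) @ dV^T @ (I - V V^T)
# when A has more rows or more columns than there are singular values.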
import numpy as np
import torch
def safe_inverse(x, epsilon=1E-12):
    return x/(x**2 + epsilon)

class SVD(torch.autograd.Function):
    @staticmethod
    def forward(self, A):
        U, S, V = torch.svd(A)
        self.save_for_backward(U, S, V)
        return U, S, V

    @staticmethod
    def backward(self, dU, dS, dV):
        U, S, V = self.saved_tensors
        Vt = V.t()
        Ut = U.t()
        M = U.size(0)
        N = V.size(0)
        NS = len(S)

        F = (S - S[:, None])
        F = safe_inverse(F)
        F.diagonal().fill_(0)

        G = (S + S[:, None])
        G.diagonal().fill_(np.inf)
        G = 1/G

        UdU = Ut @ dU
        VdV = Vt @ dV

        Su = (F+G)*(UdU-UdU.t())/2
        Sv = (F-G)*(VdV-VdV.t())/2

        dA = U @ (Su + Sv + torch.diag(dS)) @ Vt
        if (M>NS):
            dA = dA + (torch.eye(M, dtype=dU.dtype, device=dU.device) - U@Ut) @ (dU/S) @ Vt
        if (N>NS):
            dA = dA + (U/S) @ dV.t() @ (torch.eye(N, dtype=dU.dtype, device=dU.device) - V@Vt)
        return dA
def test_svd():
    M, N = 50, 40
    torch.manual_seed(2)
    input = torch.rand(M, N, dtype=torch.float64, requires_grad=True)
    assert(torch.autograd.gradcheck(SVD.apply, input, eps=1e-6, atol=1e-4))
    print("Test Pass!")

if __name__=='__main__':
    test_svd()
import io
import numpy as np
import torch
torch.manual_seed(42)
torch.set_num_threads(1)
from utils import gauge, trace
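# build_tensor constructs the single-site tensor of the 2D classical Ising model at coupling K
# and (scaled) field H: the bond Boltzmann weight factorizes as exp(K s s') = sum_i Q_{s i} Q_{s' i}
# with Q = [[sqrt(cosh K), sqrt(sinh K)], [sqrt(cosh K), -sqrt(sinh K)]], so each site carries
#   T_{ijkl} = sum_{s = +,-} exp(+/-H) Q_{s i} Q_{s j} Q_{s k} Q_{s l},
# and contracting T on the square lattice reproduces the partition function.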
def build_tensor(K, H):
    c = torch.sqrt(torch.cosh(K))
    s = torch.sqrt(torch.sinh(K))
    Q = [c, s, c, -s]
    Q = torch.stack(Q).view(2, 2)
    M = [torch.exp(H), torch.exp(-H)]
    M = torch.stack(M).view(2)
    T = torch.einsum('a,ai,aj,ak,al->ijkl', (M, Q, Q, Q, Q))
    return T
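# One HOTRG coarse-graining step: contract two copies of T along one lattice direction, fuse
# each pair of parallel transverse legs with the isometry U (truncating the bond dimension to
# at most chi), and permute the result so that the next call coarse-grains the other direction.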
def renormalize(T, U):
    T = torch.einsum('axob,amz,moyn,bnw->xwzy', (T, U, T, U))
    #TU1 = torch.tensordot(T, U, dims=([0], [0]))
    #TU2 = torch.tensordot(T, U, dims=([3], [1]))
    #T = torch.tensordot(TU1, TU2, dims=([1,2,3], [1,3,0]))
    #T = T.permute(0,3,1,2).contiguous()
    return T
class Ising(torch.nn.Module):
    def __init__(self, chi, niter, dtype=torch.float64, device='cpu', use_checkpoint=False):
        super(Ising, self).__init__()
        self.D = 2
        self.chi = chi # cutoff
        self.niter = niter
        self.dtype = dtype
        self.device = device

    def forward(self, T, enable_grad=True):
        '''
        Coarse-grain T for niter HOTRG steps and return lnZ per site.
        '''
        lnZ = 0.0
        for step in range(self.niter):
            f = T.abs().max()
            # the tensor at step `step` represents 2**step sites, so its normalization
            # contributes 2**(-step)*log(f) to lnZ per site
            lnZ += 2**(-step)*torch.log(f)
            T = T / f

            # optionally stop gradients from flowing through the isometries
            torch.set_grad_enabled(enable_grad)
            vl, el = gauge(T, self.chi, 'l')
            vr, er = gauge(T, self.chi, 'r')
            # keep the isometry with the smaller truncation error
            U = vl if (el < er) else vr
            torch.set_grad_enabled(True)

            T = renormalize(T, U)
        lnZ += torch.log(trace(T))/2**self.niter
        return lnZ
def main(beta, h=1E-12):
    beta = torch.tensor(beta, dtype=dtype, device=device)
    h = torch.tensor(h, dtype=dtype, device=device).requires_grad_()
    model = Ising(args.chi, args.niter, dtype=dtype, device=device)

    T = build_tensor(beta, beta*h)
    lnZ = model.forward(T, args.enable_grad)
    # magnetization: the field enters as beta*h, so m = 1/beta * dlnZ/dh
    dlnZ, = torch.autograd.grad(lnZ, h) #, create_graph=True)
    #dlnZ2, = torch.autograd.grad(dlnZ, h)
    return lnZ, dlnZ/beta
def scan_beta():
    key = 'D' + str(args.chi) + \
          'N' + str(args.niter)
    if args.enable_grad:
        key += '_disometry'
    with io.open(key+'.log', 'a', buffering=1, newline='\n') as logfile:
        for beta in np.linspace(0.4, 0.5, 51):
            lnZ, m = main(beta)
            message = (3*'{:.8f} ').format(beta, lnZ.item(), m.item())
            print('beta, lnZ, m', message)
            logfile.write(message + u'\n')
if __name__=='__main__':
    import argparse
    parser = argparse.ArgumentParser(description='')
    #parser.add_argument("-beta", type=float, default=0.44068679350977147, help="beta")
    parser.add_argument("-chi", type=int, default=24, help="chi")
    parser.add_argument("-niter", type=int, default=50, help="niter")
    parser.add_argument("-float32", action='store_true', help="use float32")
    parser.add_argument("-enable_grad", action='store_true', help="diff into isometry")
    parser.add_argument("-cuda", type=int, default=-1, help="GPU #")
    args = parser.parse_args()

    device = torch.device("cpu" if args.cuda<0 else "cuda:"+str(args.cuda))
    dtype = torch.float32 if args.float32 else torch.float64
    print('use', dtype)

    #main(args.beta)
    scan_beta()
import torch
from adlib import SVD
svd = SVD.apply

def gauge(T, Dcut, side):
    D = T.shape[0]
    # ... (lines elided in this diff view) ...
    M = M.view(-1, D**2)
    M = M.t()@M
    U, S, V = svd(M)   # was: U, S, V = torch.svd(M)
    Dnew = min(D**2, Dcut)
    truncation = S[Dnew:].sum() / S.sum()
    return U[:, :Dnew].view(D, D, Dnew), truncation
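# gauge builds a matrix M from a pair of legs of T (the elided lines above construct it), forms
# the positive-semidefinite M.t()@M and takes its SVD; the first Dnew columns of U, reshaped to
# (D, D, Dnew), serve as the HOTRG isometry, and `truncation` is the discarded fraction of
# singular-value weight. This commit switches from torch.svd to the adlib SVD, whose backward
# is regularized against (near-)degenerate singular values, which matters when differentiating
# through the isometries (-enable_grad).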