Advanced Computing Platform for Theoretical Physics

Commit 8d8884fc authored by Lei Wang

wrap contraction into a primitive to have (1) an efficient backward and (2) a correct double backward

parent 9c9dc1b8
from .svd import SVD
from .eigh import EigenSolver
from .qr import QR
'''
PyTorch has its own implementation of backward function for symmetric eigensolver https://github.com/pytorch/pytorch/blob/291746f11047361100102577ce7d1cfa1833be50/tools/autograd/templates/Functions.cpp#L1660
However, it assumes a triangular adjoint.
We reimplement it to return a symmetric adjoint
'''
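# Backward formula implemented below (comment added for reference):
#   A = V diag(w) V^T with adjoints (dw, dv) gives
#   dA = V [ diag(dw) + F * (V^T dv - dv^T V)/2 ] V^T,   F_ij = 1/(w_j - w_i),  F_ii = 0.
# Both F and (V^T dv - dv^T V)/2 are antisymmetric, so their elementwise product is
# symmetric and the returned adjoint is symmetric, unlike the triangular adjoint
# assumed by PyTorch's built-in implementation.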
import numpy as np
import torch
class EigenSolver(torch.autograd.Function):
@staticmethod
def forward(self, A):
w, v = torch.symeig(A, eigenvectors=True)
self.save_for_backward(w, v)
return w, v
@staticmethod
def backward(self, dw, dv):
w, v = self.saved_tensors
dtype, device = w.dtype, w.device
N = v.shape[0]
F = w - w[:,None]
# in case of degenerate eigenvalues, replace the following two lines with a safe inverse
F.diagonal().fill_(np.inf);
F = 1./F
vt = v.t()
vdv = vt@dv
return v@(torch.diag(dw) + F*(vdv-vdv.t())/2) @vt
def test_eigs():
M = 2
torch.manual_seed(42)
A = torch.rand(M, M, dtype=torch.float64)
A = torch.nn.Parameter(A+A.t())
assert(torch.autograd.gradcheck(EigenSolver.apply, A, eps=1e-6, atol=1e-4))
print("Test Pass!")
if __name__=='__main__':
test_eigs()
import torch
from torch.utils.checkpoint import detach_variable
def step(A, x):
y = A@x
y = y[0].sign() * y
return y/y.norm()
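# Comment added for clarity: FixedPoint wraps the power iteration x* = step(A, x*)
# as an autograd primitive.  Its backward solves the adjoint fixed-point equation
#     g = grad + (d step/d x)^T g
# by iterating it until the correction falls below tol (summing a geometric series),
# and then propagates the converged adjoint through d step/d A to obtain dA.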
class FixedPoint(torch.autograd.Function):
@staticmethod
def forward(ctx, A, x0, tol):
x, x_prev = step(A, x0), x0
while torch.dist(x, x_prev) > tol:
x, x_prev = step(A, x), x
ctx.save_for_backward(A, x)
ctx.tol = tol
return x
@staticmethod
def backward(ctx, grad):
A, x = detach_variable(ctx.saved_tensors)
dA = grad
while True:
with torch.enable_grad():
grad = torch.autograd.grad(step(A, x), x, grad_outputs=grad)[0]
if (torch.norm(grad) > ctx.tol):
dA = dA + grad
else:
break
with torch.enable_grad():
dA = torch.autograd.grad(step(A, x), A, grad_outputs=dA)[0]
return dA, None, None
def test_backward():
N = 4
torch.manual_seed(2)
A = torch.rand(N, N, dtype=torch.float64, requires_grad=True)
x0 = torch.rand(N, dtype=torch.float64)
x0 = x0/x0.norm()
tol = 1E-10
input = A, x0, tol
assert(torch.autograd.gradcheck(FixedPoint.apply, input, eps=1E-6, atol=tol))
print("Backward Test Pass!")
def test_forward():
torch.manual_seed(42)
N = 100
tol = 1E-8
dtype = torch.float64
A = torch.randn(N, N, dtype=dtype)
A = A+A.t()
w, v = torch.symeig(A, eigenvectors=True)
idx = torch.argmax(w.abs())
v_exact = v[:, idx]
v_exact = v_exact[0].sign() * v_exact
x0 = torch.rand(N, dtype=dtype)
x0 = x0/x0.norm()
x = FixedPoint.apply(A, x0, tol)
assert(torch.allclose(v_exact, x, rtol=tol, atol=tol))
print("Forward Test Pass!")
if __name__=='__main__':
test_forward()
test_backward()
import torch
class QR(torch.autograd.Function):
@staticmethod
def forward(self, A):
Q, R = torch.qr(A)
self.save_for_backward(A, Q, R)
return Q, R
@staticmethod
def backward(self, dq, dr):
A, q, r = self.saved_tensors
if r.shape[0] == r.shape[1]:
return _simple_qr_backward(q, r, dq, dr)
M, N = r.shape
B = A[:,M:]
dU = dr[:,:M]
dD = dr[:,M:]
U = r[:,:M]
da = _simple_qr_backward(q, U, dq+B@dD.t(), dU)
db = q@dD
return torch.cat([da, db], 1)
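# Comment added for clarity: for a wide input A (more columns than rows) the
# backward above splits A = [X | Y] and R = [U | D], with U square upper-triangular.
# Since X = Q U and Y = Q D, the adjoint of the square block follows from the
# square-QR rule with dq replaced by dq + Y dD^T, while the adjoint of Y is simply
# Q dD; the two pieces are concatenated back along the column dimension.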
def _simple_qr_backward(q, r, dq, dr):
if r.shape[-2] != r.shape[-1]:
raise NotImplementedError("QrGrad not implemented when ncols > nrows "
"or full_matrices is true and ncols != nrows.")
qdq = q.t() @ dq
qdq_ = qdq - qdq.t()
rdr = r @ dr.t()
rdr_ = rdr - rdr.t()
tril = torch.tril(qdq_ + rdr_)
def _TriangularSolve(x, r):
"""Equiv to x @ torch.inverse(r).t() if r is upper-tri."""
res = torch.trtrs(x.t(), r, upper=True, transpose=False)[0].t()
return res
grad_a = q @ (dr + _TriangularSolve(tril, r))
grad_b = _TriangularSolve(dq - q @ qdq, r)
return grad_a + grad_b
def test_qr():
M, N = 4, 6
torch.manual_seed(2)
A = torch.randn(M, N)
A.requires_grad=True
assert(torch.autograd.gradcheck(QR.apply, A, eps=1e-4, atol=1e-2))
print("Test Pass!")
if __name__ == "__main__":
test_qr()
'''
PyTorch has its own implementation of backward function for SVD https://github.com/pytorch/pytorch/blob/291746f11047361100102577ce7d1cfa1833be50/tools/autograd/templates/Functions.cpp#L1577
We reimplement it with a safe inverse function in light of degenerate singular values.
'''
import numpy as np
import torch
def safe_inverse(x, epsilon=1E-12):
return x/(x**2 + epsilon)
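# Comment added for reference: the backward below uses the weights
#   F_ij = 1/(S_j - S_i)  and  G_ij = 1/(S_j + S_i)
# to assemble dA from the (anti)symmetrized U^T dU and V^T dV.  Only F can blow up,
# namely for (near-)degenerate singular values, so it is computed with safe_inverse,
# which regularizes x -> x/(x^2 + epsilon) at the cost of an O(epsilon) bias.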
class SVD(torch.autograd.Function):
@staticmethod
def forward(self, A):
U, S, V = torch.svd(A)
self.save_for_backward(U, S, V)
return U, S, V
@staticmethod
def backward(self, dU, dS, dV):
U, S, V = self.saved_tensors
Vt = V.t()
Ut = U.t()
M = U.size(0)
N = V.size(0)
NS = len(S)
F = (S - S[:, None])
F = safe_inverse(F)
F.diagonal().fill_(0)
G = (S + S[:, None])
G.diagonal().fill_(np.inf)
G = 1/G
UdU = Ut @ dU
VdV = Vt @ dV
Su = (F+G)*(UdU-UdU.t())/2
Sv = (F-G)*(VdV-VdV.t())/2
dA = U @ (Su + Sv + torch.diag(dS)) @ Vt
if (M>NS):
dA = dA + (torch.eye(M, dtype=dU.dtype, device=dU.device) - U@Ut) @ (dU/S) @ Vt
if (N>NS):
dA = dA + (U/S) @ dV.t() @ (torch.eye(N, dtype=dU.dtype, device=dU.device) - V@Vt)
return dA
def test_svd():
M, N = 50, 40
torch.manual_seed(2)
input = torch.rand(M, N, dtype=torch.float64, requires_grad=True)
assert(torch.autograd.gradcheck(SVD.apply, input, eps=1e-6, atol=1e-4))
print("Test Pass!")
if __name__=='__main__':
test_svd()
import torch
from tensornets import CTMRG
def ctmrg(T, chi, maxiter, epsilon, use_checkpoint=False):
'''
customize the backward of the ctmrg contraction:
we directly provide the adjoint of T, so the backward does not need to differentiate through C and E;
however, since the backward runs under enable_grad, double backward will differentiate through C and E
'''
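# Comment added for clarity: with the converged environment (C, E) treated as
# constants, only Z1 depends on T explicitly, so
#     dlnZ/dT = (dZ1/dT) / Z1,
# and dZ1/dT is the same network with the T tensor removed (the '->chij' hole in
# the einsum used in the backward below).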
class Contraction(torch.autograd.Function):
@staticmethod
def forward(self, T, chi, maxiter, epsilon, use_checkpoint):
self.chi = chi
self.maxiter = maxiter
self.epsilon = epsilon
self.use_checkpoint = use_checkpoint
C, E = CTMRG(T, chi, maxiter, epsilon, use_checkpoint=use_checkpoint)
Z1 = torch.einsum('ab,bcd,fd,gha,chij,fjk,lg,mil,mk', (C,E,C,E,T,E,C,E,C))
Z3 = torch.einsum('ab,bc,cd,da', (C,C,C,C))
Z2 = torch.einsum('ab,bcd,de,fa,gcf,ge',(C,E,C,C,E,C))
-lnZ = torch.log(Z1.abs()) + torch.log(Z3.abs()) - 2.*torch.log(Z2.abs())
-return lnZ
+self.save_for_backward(T)
+return torch.log(Z1.abs()) + torch.log(Z3.abs()) - 2.*torch.log(Z2.abs())
@staticmethod
def backward(self, dlnZ):
T, = self.saved_tensors
with torch.enable_grad():
C, E = CTMRG(T, self.chi, self.maxiter, self.epsilon, use_checkpoint=self.use_checkpoint)
up = torch.einsum('ab,bcd,fd,gha,fjk,lg,mil,mk->chij', (C,E,C,E,E,C,E,C)) * dlnZ
dn = torch.einsum('ab,bcd,fd,gha,chij,fjk,lg,mil,mk', (C,E,C,E,T,E,C,E,C))
return up/dn, None, None, None, None
def symmetrize(A):
'''
@@ -24,3 +46,43 @@ def symmetrize(A):
return Asymm/Asymm.norm()
if __name__=='__main__':
import numpy as np
d = 2
D = 2
chi = 50
maxiter = 50
epsilon = 1E-10
dtype = torch.float64
#dimer covering
T = torch.zeros(d, d, d, d, d, d, dtype=dtype)
T[0, 0, 0, 0, 0, 1] = 1.0
T[0, 0, 0, 0, 1, 0] = 1.0
T[0, 0, 0, 1, 0, 0] = 1.0
T[0, 0, 1, 0, 0, 0] = 1.0
T[0, 1, 0, 0, 0, 0] = 1.0
T[1, 0, 0, 0, 0, 0] = 1.0
T = T.view(d, d**4, d)
contract = Contraction.apply
def f(x):
A = torch.as_tensor(x).view(d, D, D, D, D)
As = symmetrize(A).view(d, D**4)
T2 = (As.t()@As).view(D, D, D, D, D, D, D, D).permute(0,4, 1,5, 2,6, 3,7).contiguous().view(D**2, D**2, D**2, D**2)
return -contract(T2, chi, maxiter, epsilon)
def g(x):
A = torch.as_tensor(x).view(d, D, D, D, D).requires_grad_()
As = symmetrize(A).view(d, D**4)
T2 = (As.t()@As).view(D, D, D, D, D, D, D, D).permute(0,4, 1,5, 2,6, 3,7).contiguous().view(D**2, D**2, D**2, D**2)
loss = -contract(T2, chi, maxiter, epsilon)
loss.backward()
return (A.grad).numpy().ravel()
from scipy.optimize import check_grad
x = np.random.randn(d*D**4)
print (f(x), g(x).shape)
print ('gradient check:', check_grad(f, g, x))
import torch
from itertools import permutations
from adlib import SVD
svd = SVD.apply
def build_tensor(K, H):
c = torch.sqrt(torch.cosh(K))
s = torch.sqrt(torch.sinh(K))
Q = [c, s, c, -s]
Q = torch.stack(Q).view(2, 2)
M = [torch.exp(H), torch.exp(-H)]
M = torch.stack(M).view(2)
T = torch.einsum('a,ai,aj,ak,al->ijkl', (M, Q, Q, Q, Q))
return T
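# Comment added for reference: build_tensor encodes the 2d Ising model.  With
#   Q = [[sqrt(cosh K),  sqrt(sinh K)],
#        [sqrt(cosh K), -sqrt(sinh K)]]
# one has sum_i Q_ai Q_bi = exp(K s_a s_b) for spins s_a, s_b in {+1, -1}, and
# M_a = exp(+H), exp(-H) carries the field term, so
#   T_ijkl = sum_a M_a Q_ai Q_aj Q_ak Q_al
# is a vertex tensor whose full contraction reproduces the partition function.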
def ctmrg(T, d, Dcut, max_iter):
lnZ = 0.0
truncation_error = 0.0
#C = torch.randn(d, d, dtype=T.dtype, device=T.device) #T.sum((0,1))
#E = torch.randn(d, d, d, dtype=T.dtype, device=T.device)#T.sum(1)
C = T.sum((0,1))
E = T.sum(1)
D = d
sold = torch.zeros(d, dtype=T.dtype, device=T.device)
diff = 1E1
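# Comment added for clarity: each iteration below performs one CTMRG sweep for a
# symmetric tensor.  The corner is grown as C-E-E-T, symmetrized, and truncated
# with the isometry P built from its D_new leading singular vectors; the edge is
# grown with T and projected with the same P.  Convergence is monitored through
# the normalized corner spectrum s.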
for n in range(max_iter):
A = torch.einsum('ab,eca,bdg,cdfh->efgh', (C, E, E, T)).contiguous().view(D*d, D*d)
A = (A+A.t())/2.
D_new = min(D*d, Dcut)
U, S, V = svd(A)
s = S/S.max()
truncation_error += S[D_new:].sum()/S.sum()
P = U[:, :D_new] # projection operator
#S, U = torch.symeig(A, eigenvectors=True)
#truncation_error += 1.-S[-D_new:].sum()/S.sum()
#P = U[:, -D_new:] # projection operator
C = (P.t() @ A @ P) #(D, D)
C = (C+C.t())/2.
ET = torch.einsum('ldr,adbc->labrc', (E, T)).contiguous().view(D*d, d, D*d)
#ET = torch.tensordot(E, T, dims=([1], [1]))
#ET = ET.permute(0, 2, 3, 1, 4).contiguous().view(D*d, d, D*d)
E = torch.einsum('li,ldr,rj->idj', (P, ET, P)) #(D_new, d, D_new)
#E = ( P.t() @ ((ET.view(D*d*d, D*d)@P).view(D*d, d*D_new))).view(D_new,d,D_new)
E = (E + E.permute(2, 1, 0))/2.
D = D_new
C = C/C.norm()
E = E/E.norm()
if (s.numel() == sold.numel()):
diff = (s-sold).norm()
if (diff < 1E-8):
break
sold = s
#print ('ctmrg iterations', n)
#C = C.detach()
#E = E.detach()
Z1 = torch.einsum('ab,bcd,fd,gha,hcij,fjk,lg,mil,mk', (C,E,C,E,T,E,C,E,C))
#CEC = torch.einsum('da,ebd,ce->abc', (C,E,C)).view(1, D**2*d)
#ETE = torch.einsum('abc,lbdr,mdn->almcrn',(E,T,E)).contiguous().view(D**2*d, D**2*d)
#ETE = (ETE+ETE.t())/2.
#Z1 = CEC@ETE@CEC.t()
Z3 = torch.einsum('ab,bc,cd,da', (C,C,C,C))
Z2 = torch.einsum('ab,bcd,de,fa,gcf,ge',(C,E,C,C,E,C))
#print (' Z1, Z2, Z3:', Z1.item(), Z2.item(), Z3.item())
lnZ += torch.log(Z1.abs()) + torch.log(Z3.abs()) - 2.*torch.log(Z2.abs())
return lnZ, truncation_error/n
if __name__=='__main__':
torch.set_num_threads(1)
beta = 0.44
J = 1.0
h = 0.0
beta = torch.tensor(beta, dtype=torch.float64).requires_grad_()
h = torch.tensor(h, dtype=torch.float64).requires_grad_()
Dcut = 100
max_iter = 1000
T = build_tensor(beta*J, beta*h)
lnZ, error = ctmrg(T, 2, Dcut, max_iter)
print (lnZ.item(), error.item())
dlnZ, = torch.autograd.grad(lnZ, h, create_graph=True)
dlnZ2, = torch.autograd.grad(dlnZ, h, create_graph=True)
print (dlnZ.item(), dlnZ2.item())
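# Note added for clarity: the field enters build_tensor as H = beta*h, so dlnZ/dh
# equals beta*<m> and the second derivative equals beta*d<m>/dh (magnetization and
# susceptibility up to factors of beta); computing the latter is exactly the
# double-backward use case this commit addresses.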
'''
This code solves the counting problem of [1] using differentiable TRG programming [2].
In a nutshell, it computes the maximum eigenvalue of the transfer matrix via the variational principle.
[1] Vanderstraeten et al, PRE 98, 042145 (2018)
[2] Levin and Nave, PRL 99, 120601 (2007)
'''
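# Note added for clarity (a sketch of the variational principle referred to above):
# for the row-to-row transfer matrix T of the dimer model, the leading eigenvalue
# satisfies
#     ln(lambda_max) = max_A  ln( <A|T|A> / <A|A> )
# over boundary PEPS |A> of bond dimension D.  Below, lnT and lnZ are the CTMRG
# estimates of ln<A|T|A> and ln<A|A>, so the optimizer minimizes loss = -lnT + lnZ.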
import torch
torch.set_num_threads(1)
torch.manual_seed(42)
-from contraction import symmetrize, ctmrg
+from contraction import symmetrize, Contraction
contract = Contraction.apply
if __name__=='__main__':
import argparse
parser = argparse.ArgumentParser(description='')
parser.add_argument("-D", type=int, default=2, help="D")
parser.add_argument("-chi", type=int, default=30, help="chi")
parser.add_argument("-maxiter", type=int, default=50, help="maxiter")
parser.add_argument("-epsilon", type=float, default=1E-10, help="maxiter")
parser.add_argument("-nepochs", type=int, default=10, help="nepochs")
parser.add_argument("-float32", action='store_true', help="use float32")
parser.add_argument("-use_checkpoint", action='store_true', help="use checkpoint")
parser.add_argument("-cuda", type=int, default=-1, help="use GPU")
parser.add_argument("--D", type=int, default=2, help="D")
parser.add_argument("--chi", type=int, default=50, help="chi")
parser.add_argument("--maxiter", type=int, default=50, help="maxiter")
parser.add_argument("--epsilon", type=float, default=1E-12, help="maxiter")
parser.add_argument("--nepochs", type=int, default=10, help="nepochs")
parser.add_argument("--float32", action='store_true', help="use float32")
parser.add_argument("--use_checkpoint", action='store_true', help="use checkpoint")
parser.add_argument("--cuda", type=int, default=-1, help="use GPU")
args = parser.parse_args()
device = torch.device("cpu" if args.cuda<0 else "cuda:"+str(args.cuda))
dtype = torch.float32 if args.float32 else torch.float64
@@ -29,9 +24,9 @@ if __name__=='__main__':
d = 2 # fixed
D = args.D
-B = 0.01* torch.randn(d, D, D, D, D, dtype=dtype, device=device)
-#symmetrize initial boundary PEPS
-A = torch.nn.Parameter(symmetrize(B))
+B = torch.rand(d, D, D, D, D, dtype=dtype, device=device)
+B = B/B.norm()
+A = torch.nn.Parameter(B)
#dimer covering
T = torch.zeros(d, d, d, d, d, d, dtype=dtype, device=device)
@@ -43,7 +38,7 @@ if __name__=='__main__':
T[1, 0, 0, 0, 0, 0] = 1.0
T = T.view(d, d**4, d)
-optimizer = torch.optim.LBFGS([A], max_iter=20, tolerance_grad=0, tolerance_change=0, line_search_fn="strong_wolfe")
+optimizer = torch.optim.LBFGS([A], max_iter=10, tolerance_grad=0, tolerance_change=0, line_search_fn="strong_wolfe")
def closure():
optimizer.zero_grad()
@@ -53,8 +48,8 @@ if __name__=='__main__':
#double layer
T2 = (As.t()@As).view(D, D, D, D, D, D, D, D).permute(0,4, 1,5, 2,6, 3,7).contiguous().view(D**2, D**2, D**2, D**2)
-lnT = ctmrg(T1, args.chi, args.maxiter, args.epsilon, use_checkpoint=args.use_checkpoint)
-lnZ = ctmrg(T2, args.chi, args.maxiter, args.epsilon, use_checkpoint=args.use_checkpoint)
+lnT = contract(T1, args.chi, args.maxiter, args.epsilon, args.use_checkpoint)
+lnZ = contract(T2, args.chi, args.maxiter, args.epsilon, args.use_checkpoint)
loss = (-lnT + lnZ)
loss.backward()
@@ -74,8 +69,8 @@ if __name__=='__main__':
T1 = torch.einsum('xa,xby,yc' , (As,T,As)).view(D,D,D,D, d,d,d,d, D,D,D,D).permute(0,4,8, 1,5,9, 2,6,10, 3,7,11).contiguous().view(D**2*d, D**2*d, D**2*d, D**2*d)
T2 = (As.t()@As).view(D, D, D, D, D, D, D, D).permute(0,4, 1,5, 2,6, 3,7).contiguous().view(D**2, D**2, D**2, D**2)
-lnT = ctmrg(T1, args.chi, args.maxiter, args.epsilon, use_checkpoint=args.use_checkpoint)
-lnZ = ctmrg(T2, args.chi, args.maxiter, args.epsilon, use_checkpoint=args.use_checkpoint)
+lnT = contract(T1, args.chi, args.maxiter, args.epsilon, args.use_checkpoint)
+lnZ = contract(T2, args.chi, args.maxiter, args.epsilon, args.use_checkpoint)
loss = (-lnT + lnZ)
loss.backward(create_graph=True)
@@ -84,6 +79,8 @@ if __name__=='__main__':
info['loss'] = loss
info['A'] = A
print ('A.grad in fun', A.grad)
print (info['feval'], loss.item(), A.grad.norm().item())
return loss.item(), A.grad.detach().cpu().numpy().ravel()
@@ -99,8 +96,8 @@ if __name__=='__main__':
T1 = torch.einsum('xa,xby,yc' , (As,T,As)).view(D,D,D,D, d,d,d,d, D,D,D,D).permute(0,4,8, 1,5,9, 2,6,10, 3,7,11).contiguous().view(D**2*d, D**2*d, D**2*d, D**2*d)
T2 = (As.t()@As).view(D, D, D, D, D, D, D, D).permute(0,4, 1,5, 2,6, 3,7).contiguous().view(D**2, D**2, D**2, D**2)
-lnT = ctmrg(T1, args.chi, args.maxiter, args.epsilon, use_checkpoint=args.use_checkpoint)
-lnZ = ctmrg(T2, args.chi, args.maxiter, args.epsilon, use_checkpoint=args.use_checkpoint)
+lnT = contract(T1, args.chi, args.maxiter, args.epsilon, args.use_checkpoint)
+lnZ = contract(T2, args.chi, args.maxiter, args.epsilon, args.use_checkpoint)
loss = (-lnT + lnZ)
loss.backward(create_graph=True)
@@ -116,7 +113,6 @@ if __name__=='__main__':
A.grad.data = Agrad #put it back
return res
import scipy.optimize
x0 = A.detach().cpu().numpy().ravel()
x = scipy.optimize.minimize(fun, x0, args=({'feval':0},), jac=True, hessp=hessp, method='Newton-CG', options={'xtol':1E-8})
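# Note added for clarity: Newton-CG only needs Hessian-vector products.  hessp
# obtains them by differentiating the gradient itself (hence the
# loss.backward(create_graph=True) calls above), which is the "correct double
# backward" requirement that motivated wrapping the contraction into a primitive.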
from .ctmrg import CTMRG
from .singlelayer import projection, ctmrg, vumps
from .utils import save_checkpoint, load_checkpoint, kronecker_product, symmetrize
from .ctmrg import CTMRG, renormalize
@@ -84,25 +84,4 @@ def CTMRG(T, chi, max_iter, epsilon, use_checkpoint=False):
return C, E
if __name__=='__main__':
import time
torch.manual_seed(42)
D = 6
chi = 80
max_iter = 100
device = 'cpu'
# T(u,l,d,r)
T = torch.randn(D, D, D, D, dtype=torch.float64, device=device, requires_grad=True)
T = (T + T.permute(0, 3, 2, 1))/2. # left-right symmetry
T = (T + T.permute(2, 1, 0, 3))/2. # up-down symmetry
T = (T + T.permute(3, 2, 1, 0))/2. # skew-diagonal symmetry
T = (T + T.permute(1, 0, 3, 2))/2. # diagonal symmetry
T = T/T.norm()
C, E = CTMRG(T, chi, max_iter, 1E-10, use_checkpoint=True)  # epsilon=1E-10 assumed; the original call omitted the required epsilon argument
C, E = CTMRG(T, chi, max_iter, 1E-10, use_checkpoint=False)
print( 'diffC = ', torch.dist( C, C.t() ) )
print( 'diffE = ', torch.dist( E, E.permute(2,1,0) ) )
import torch
from adlib import SVD
svd = SVD.apply
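# Comment added for clarity: levin_nave_trg implements the Levin-Nave TRG step.  At
# each iteration the rank-4 tensor T is split along its two diagonals by truncated
# SVDs (keeping at most Dcut singular values), and the four half-tensors S1..S4 are
# recontracted into a coarse-grained T; the number of tensors halves every step, and
# the factored-out maxval rescalings are accumulated into lnZ with weight 2**(no_iter-n).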
def levin_nave_trg(T, D, Dcut, no_iter):
lnZ = 0.0
truncation_error = 0.0
for n in range(no_iter):
#print(n, " ", T.max(), " ", T.min())
maxval = T.max()
T = T/maxval
lnZ += 2**(no_iter-n)*torch.log(maxval)
#print (2**(no_iter-n)*torch.log(maxval)/2**(no_iter))
Ma = T.permute(2, 1, 0, 3).contiguous().view(D**2, D**2)
Mb = T.permute(3, 2, 1, 0).contiguous().view(D**2, D**2)
Ua, Sa, Va = svd(Ma)
Ub, Sb, Vb = svd(Mb)
D_new = min(D**2, Dcut)
truncation_error += Sa[D_new:].sum()/Sa.sum() + Sb[D_new:].sum()/Sb.sum()
S1 = (Ua[:, :D_new]* torch.sqrt(Sa[:D_new])).view(D, D, D_new)
S3 = (Va[:, :D_new]* torch.sqrt(Sa[:D_new])).view(D, D, D_new)
S2 = (Ub[:, :D_new]* torch.sqrt(Sb[:D_new])).view(D, D, D_new)
S4 = (Vb[:, :D_new]* torch.sqrt(Sb[:D_new])).view(D, D, D_new)
#T_new = torch.einsum('war,abu,bgl,gwd->ruld', (S1, S2, S3, S4))
T_new = torch.tensordot( torch.tensordot(S1, S2, dims=([1], [0])),
torch.tensordot(S3, S4, dims=([1], [0])), dims=([2, 0], [0, 2]))
D = D_new
T = T_new
trace = 0.0
for x in range(D):
for y in range(D):
trace += T[x, y, y, x]
#print (torch.log(trace)/2**no_iter)
lnZ += torch.log(trace)
return lnZ/2**no_iter, truncation_error
if __name__=='__main__':
K = torch.tensor([0.44])
Dcut = 20
n = 20
c = torch.sqrt(torch.cosh(K)/2.)
s = torch.sqrt(torch.sinh(K)/2.)
M = torch.stack([torch.cat([c+s, c-s]), torch.cat([c-s, c+s])])
T = torch.einsum('ai,aj,ak,al->ijkl', (M, M, M, M))
lnZ, error = levin_nave_trg(T, 2, Dcut, n)
print (lnZ.item(), error)