# Advanced Computing Platform for Theoretical Physics
#
# fixedpoint.py
import torch

def step(T, C, E, chi):
    """Perform one CTMRG renormalization step with a single corner and edge.

    The same corner ``C`` and edge ``E`` are used on both sides of the
    enlarged corner, which is contracted with one site tensor ``T``. The
    result is symmetrized into a density matrix ``Rho``, and its leading
    ``chi`` left singular vectors form the isometry used to renormalize
    ``C`` and ``E``.

    Args:
        T: site tensor with four legs of dimension ``dimT = T.shape[0]``
            (index roles as in the contraction comments below).
        C: corner tensor of shape (dimE, dimE).
        E: edge tensor of shape (dimE, dimT, dimE).
        chi: maximal bond dimension kept after truncation.

    Returns:
        ``(C_new, E_new, s)``. ``C_new`` has shape (D_new, D_new) and
        ``E_new`` has shape (D_new, dimT, D_new), with
        ``D_new = min(dimE * dimT, chi)``; both have unit Frobenius norm.
        ``s`` holds all ``dimE * dimT`` singular values of ``Rho`` (not
        truncated), divided by the largest one.
    """
    dimT, dimE = T.shape[0], E.shape[0]
    D_new = min(dimE * dimT, chi)

    # Step 1: construct the density matrix Rho from the enlarged corner.
    Rho = torch.tensordot(C, E, ([1], [0]))          # C(ef)*EU(fga)=Rho(ega)
    Rho = torch.tensordot(Rho, E, ([0], [0]))        # Rho(ega)*EL(ehc)=Rho(gahc)
    Rho = torch.tensordot(Rho, T, ([0, 2], [0, 1]))  # Rho(gahc)*T(ghdb)=Rho(acdb)
    Rho = Rho.permute(0, 3, 1, 2).contiguous().view(dimE * dimT, dimE * dimT)  # Rho(acdb)->Rho(ab;cd)

    # Make Rho symmetric and scale it to unit Frobenius norm.
    Rho = Rho + Rho.t()
    Rho = Rho / Rho.norm()

    # Step 2: isometry P = leading left singular vectors of Rho.
    # (A symmetric eigensolver would also work since Rho is symmetric, but
    # the eigenvalues would then have to be sorted by magnitude.)
    U, S, _ = torch.svd(Rho)
    P = U[:, :D_new]  # projection operator

    # Fix the sign gauge of each column so that the first row of P is
    # non-negative. A column whose first entry is exactly zero keeps its sign:
    # multiplying by sign(0) == 0 would erase that column and make the new
    # corner rank-deficient.
    first_row = P[0]
    gauge = torch.where(first_row < 0,
                        -torch.ones_like(first_row),
                        torch.ones_like(first_row))
    P = P * gauge

    # Step 3: renormalize C and E.
    C = P.t() @ Rho @ P  # C(D_new, D_new)

    # EL(u,r,d)
    P = P.reshape(dimE, dimT, D_new)
    E = torch.tensordot(E, P, ([0], [0]))          # EL(def)P(dga)=E(efga)
    E = torch.tensordot(E, T, ([0, 2], [1, 0]))    # E(efga)T(gehb)=E(fahb)
    E = torch.tensordot(E, P, ([0, 2], [0, 1]))    # E(fahb)P(fhc)=E(abc)

    # Step 4: symmetrize C and E.
    C = 0.5 * (C + C.t())
    E = 0.5 * (E + E.permute(2, 1, 0))

    # Singular values are non-negative, so no abs() is needed.
    return C / C.norm(), E / E.norm(), S / S.max()

class CTMRG(torch.autograd.Function):
    """CTMRG fixed point of ``step`` with an implicit (fixed-point) backward.

    ``forward`` iterates ``step`` until consecutive corner and edge tensors
    agree to ``tol``. ``backward`` differentiates the fixed-point relation
    ``(C*, E*) = step(T, C*, E*)``. It sums the series
    ``v = g + J^T g + (J^T)^2 g + ...``, where ``J`` is the Jacobian of
    ``step`` with respect to ``(C, E)`` at the saved fixed point, and returns
    ``(d step / d T)^T v``.
    """

    @staticmethod
    def forward(ctx, T, chi, maxiter=50, tol=1E-12):
        """Iterate ``step`` starting from an environment built out of ``T``.

        Args:
            T: site tensor. The initial corner is ``T.sum((0, 1))`` and the
                initial edge is ``T.sum(1).permute(0, 2, 1)``.
            chi: maximal bond dimension passed to ``step``.
            maxiter: maximum number of CTMRG iterations; it is also reused as
                the iteration cap in ``backward``.
            tol: threshold on ``torch.dist`` between consecutive corners and
                between consecutive edges; it is also reused in ``backward``.

        Returns:
            ``(C, E)``. This is the input of the step that met ``tol``, or
            the last iterate if ``maxiter`` was exhausted.
        """
        C = T.sum((0, 1))
        E = T.sum(1).permute(0, 2, 1)
        s = None
        for n in range(maxiter):
            C_new, E_new, s_new = step(T, C, E, chi)

            # Compare iterates only once the bond dimension has stopped
            # growing, i.e. consecutive spectra and corners have equal
            # shapes. Before that, torch.dist would fail on mismatched shapes.
            if s is not None and s.shape == s_new.shape and C.shape == C_new.shape:
                diff1 = torch.dist(C, C_new)
                diff2 = torch.dist(E, E_new)
                diff = torch.dist(s, s_new)
                print(n, diff1.item(), diff2.item(), diff.item())
                if diff1 < tol and diff2 < tol:
                    break
            C, E, s = C_new, E_new, s_new

        # Non-tensor settings needed to re-run `step` in backward.
        ctx.chi, ctx.maxiter, ctx.tol = chi, maxiter, tol
        ctx.save_for_backward(T, C, E)
        return C, E

    @staticmethod
    def backward(ctx, grad_C, grad_E):
        """Return the gradient w.r.t. ``T`` via the fixed-point series.

        The series is truncated once the norm of the newest term is at most
        ``tol``, or after ``maxiter`` terms (both taken from ``forward``).
        NOTE(review): this assumes ``forward`` actually converged, so the
        saved ``(C, E)`` has the same shapes as ``step``'s output, and that
        the series converges (spectral radius of J below 1). Confirm for new
        models.
        """
        T, C_star, E_star = (t.detach().requires_grad_() for t in ctx.saved_tensors)

        # Build the graph of one step at the fixed point once and reuse it
        # (retain_graph=True) for every vector-Jacobian product below.
        with torch.enable_grad():
            C_new, E_new, _ = step(T, C_star, E_star, ctx.chi)
        outputs = (C_new, E_new)
        fixed_point = (C_star, E_star)

        def _norm(tensors):
            return torch.sqrt(sum((g * g).sum() for g in tensors))

        grads = (grad_C, grad_E)
        total = grads
        n, grad_norm = 0, _norm(grads)
        for n in range(ctx.maxiter):
            # grads <- J^T grads
            grads = torch.autograd.grad(outputs, fixed_point,
                                        grad_outputs=grads, retain_graph=True)
            grad_norm = _norm(grads)
            if grad_norm > ctx.tol:
                total = tuple(acc + g for acc, g in zip(total, grads))
            else:
                break
        print('backward converged to', n, grad_norm.item())

        # Final VJP through the explicit T dependence; frees the graph.
        dT, = torch.autograd.grad(outputs, (T,), grad_outputs=total)
        # One gradient per forward input: T, chi, maxiter, tol.
        return dT, None, None, None

if __name__ == '__main__':
    torch.manual_seed(42)
    d = 2
    chi = 40
    device = 'cpu'
    dtype = torch.float64

    # T(u,l,d,r): nonzero (= 1) only where exactly one of the four indices is 1.
    T = torch.zeros(d, d, d, d, dtype=dtype, device=device)
    T[0, 0, 0, 1] = 1.0
    T[0, 0, 1, 0] = 1.0
    T[0, 1, 0, 0] = 1.0
    T[1, 0, 0, 0] = 1.0

    ctmrg = CTMRG.apply
    C, E = ctmrg(T, chi, 1000, 1E-6)

    # Refine the CTMRG result by solving step(x) - x = 0 for the flattened
    # (C, E) with SciPy's Newton-Krylov root finder.
    x = torch.cat([C.reshape(-1), E.reshape(-1)]).numpy()

    def fun(x):
        """Residual step(C, E) - (C, E) on the flattened corner/edge vector."""
        C = torch.as_tensor(x[:chi**2], dtype=T.dtype, device=T.device).view(chi, chi)
        E = torch.as_tensor(x[chi**2:], dtype=T.dtype, device=T.device).view(chi, d, chi)
        C, E, _ = step(T, C, E, chi)
        return torch.cat([C.view(-1), E.view(-1)]).numpy() - x

    from scipy import optimize
    sol = optimize.root(fun, x, method='krylov',
                        options={'fatol': 1E-13, 'disp': True,
                                 'jac_options': {'method': 'lgmres'}})
    print(sol)