Advanced Computing Platform for Theoretical Physics

commit大文件会使得服务器变得不稳定,请大家尽量只commit代码,不要commit大的文件。

dimer_covering.py 3.14 KB
Newer Older
1 2 3 4 5 6 7 8 9 10
'''
This code solves the dimer-covering counting problem [1] using differentiable TRG programming [2].
In a nutshell, it computes the maximum eigenvalue of the transfer matrix via the variational principle.
[1] Vanderstraeten et al, PRE 98, 042145 (2018)
[2] Levin and Nave, PRL 99, 120601 (2007)
'''
import torch 
# Fixed seed so the random initial PEPS / boundary-MPS tensors are reproducible.
torch.manual_seed(42)
# Cap PyTorch's intra-op CPU parallelism at four threads.
torch.set_num_threads(4)

Lei Wang's avatar
Lei Wang committed
11
from vmps import vmps as contraction
12 13 14 15 16 17

if __name__=='__main__':
    import time
    import argparse

    # Command-line options. From their use below: -D is the virtual bond
    # dimension of the PEPS tensor A, -Dcut the bond dimension of the boundary
    # MPS tensors A1/A2, and -Niter / -lanczos_steps are forwarded unchanged to
    # the (imported) vmps contraction routine.
    parser = argparse.ArgumentParser(description='')
    parser.add_argument("-D", type=int, default=2, help="D")
    parser.add_argument("-Dcut", type=int, default=20, help="Dcut")
    parser.add_argument("-Niter", type=int, default=10, help="Niter")
    parser.add_argument("-float32", action='store_true', help="use float32")
    parser.add_argument("-lanczos_steps", type=int, default=0, help="lanczos steps")
    parser.add_argument("-cuda", type=int, default=-1, help="use GPU")
    args = parser.parse_args()
    # A negative -cuda index (the default) means CPU; otherwise use cuda:<index>.
    device = torch.device("cpu" if args.cuda<0 else "cuda:"+str(args.cuda))
    dtype = torch.float32 if args.float32 else torch.float64

    # Fix: the option is stored as `args.lanczos_steps`. The previous
    # `args.lanczos` does not exist and raised AttributeError at startup.
    if args.lanczos_steps > 0:
        print('lanczos steps', args.lanczos_steps)

    d = 2  # physical dimension (fixed)

    D = args.D
    Dcut = args.Dcut
    Niter = args.Niter

    # Initial PEPS tensor: axis 0 is the physical leg, axes 1-4 the virtual legs.
    B = 0.01 * torch.randn(d, D, D, D, D, dtype=dtype, device=device)
    # Symmetrize the initial PEPS over its virtual legs by averaging with
    # permuted copies: swap 1<->4, swap 2<->3, swap (1,2)<->(3,4),
    # and swap 1<->2 together with 3<->4.
    B = (B + B.permute(0, 4, 2, 3, 1))/2.
    B = (B + B.permute(0, 1, 3, 2, 4))/2.
    B = (B + B.permute(0, 3, 4, 1, 2))/2.
    B = (B + B.permute(0, 2, 1, 4, 3))/2.
    # Fuse the four virtual legs into one, so A has shape (d, D**4).
    A = torch.nn.Parameter(B.view(d, D**4))

    # Boundary MPS tensors of shape (Dcut, k, Dcut). Their middle leg matches
    # the fused legs built in closure(): k = D**2*d for the single-layer
    # network contracted with A1, and k = D**2 for the double layer with A2.
    A1 = torch.nn.Parameter(0.01*torch.randn(Dcut, D**2*d, Dcut, dtype=dtype, device=device))
    A2 = torch.nn.Parameter(0.01*torch.randn(Dcut, D**2, Dcut, dtype=dtype, device=device))

    # Dimer-covering tensor with six legs of dimension d. An entry is 1 exactly
    # when one leg is in state 1 and the other five are 0 (presumably: each
    # vertex is covered by exactly one dimer). Legs 0 and 5 stay separate and
    # are contracted with A in closure(); the middle four legs are fused.
    T = torch.zeros(d, d, d, d, d, d, dtype=dtype, device=device)
    T[0, 0, 0, 0, 0, 1] = 1.0
    T[0, 0, 0, 0, 1, 0] = 1.0
    T[0, 0, 0, 1, 0, 0] = 1.0
    T[0, 0, 1, 0, 0, 0] = 1.0
    T[0, 1, 0, 0, 0, 0] = 1.0
    T[1, 0, 0, 0, 0, 0] = 1.0
    T = T.view(d, d**4, d)

    # Only A is handed to the optimizer; A1/A2 are passed to the contraction
    # routine inside closure() instead.
    optimizer = torch.optim.LBFGS([A], max_iter=20)
    def closure():
        """Recompute the variational loss and its gradient with respect to A.

        Passed to ``optimizer.step`` (LBFGS), which may call it several times
        per step. Builds two networks from the current PEPS tensor ``A``:

        * the single layer A-T-A, contracted by ``contraction`` with ``A1``;
        * the double layer A-A, contracted by ``contraction`` with ``A2``.

        ``contraction`` is the imported ``vmps``; its internals are not
        visible here. The returned loss is ``lnZ - lnT``; ``backward()`` has
        already been called on it, and timings plus the loss value are printed.
        """
        optimizer.zero_grad()

        # Single layer: A's physical index x meets T's first leg, and y meets
        # T's last leg. The result has shape (D**4, d**4, D**4).
        sandwich = torch.einsum('xa,xby,yc' , (A,T,A))
        # Unfold into 4 x D, 4 x d, 4 x D legs. Gather, for k = 0..3, the k-th
        # D-leg of the first A, the k-th d-leg of T and the k-th D-leg of the
        # second A, then fuse each triple into one leg of size D**2*d.
        sandwich = sandwich.view(D, D, D, D, d, d, d, d, D, D, D, D)
        sandwich = sandwich.permute(0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11)
        fused = D**2 * d
        single_layer = sandwich.contiguous().view(fused, fused, fused, fused)

        # Double layer: contract the physical index of A with itself, then
        # pair the matching virtual legs of both copies into legs of size D**2.
        overlap = (A.t() @ A).view(D, D, D, D, D, D, D, D)
        overlap = overlap.permute(0, 4, 1, 5, 2, 6, 3, 7)
        double_layer = overlap.contiguous().view(D**2, D**2, D**2, D**2)

        start = time.time()
        lnT = contraction(single_layer, fused, Dcut, Niter, A1, lanczos_steps=args.lanczos_steps)
        lnZ = contraction(double_layer, D**2, Dcut, Niter, A2, lanczos_steps=args.lanczos_steps)
        loss = lnZ - lnT
        print('    contraction done {:.3f}s'.format(time.time() - start))
        print('    total loss', loss.item())

        start = time.time()
        loss.backward()
        print('    backward done {:.3f}s'.format(time.time() - start))
        return loss
    # Run 100 LBFGS steps. NOTE(review): per PyTorch's LBFGS implementation,
    # step() returns the loss from the first closure() call of that step,
    # i.e. the value before the step's update is applied.
    for step_idx in range(100):
        step_loss = optimizer.step(closure)
        residual_entropy = -step_loss.item()
        print('epoch, residual entropy', step_idx, residual_entropy)