Advanced Computing Platform for Theoretical Physics

commit大文件会使得服务器变得不稳定,请大家尽量只commit代码,不要commit大的文件。

Commit 246dca5e authored by Lei Wang's avatar Lei Wang
Browse files

added vmps

parent 400ab95a
import torch
torch.set_num_threads(4)
def vmps(T, d, D, no_iter, Nepochs):
A = torch.nn.Parameter(0.01*torch.randn(D, d, D, dtype=torch.float64, device=device))
def mpsrg(B, C):
lnZ1 = 0.0
lnZ2 = 0.0
for i in range(no_iter):
s = B.norm(1)
lnZ1 = lnZ1 + torch.log(s)/2**i
B = B/s
B = torch.mm(B, B)
s = C.norm(1)
lnZ2 = lnZ2 + torch.log(s)/2**i
C = C/s
C = torch.mm(C, C)
#lnZ1 = lnZ1 + torch.log(torch.trace(B))/2**no_iter
#lnZ2 = lnZ2 + torch.log(torch.trace(C))/2**no_iter
#print (torch.log(torch.trace(B))/2**no_iter, torch.log(torch.trace(C))/2**no_iter)
return lnZ1 , lnZ2
optimizer = torch.optim.LBFGS([A], max_iter=20)
def closure():
optimizer.zero_grad()
Asymm = (A + A.permute(2, 1, 0))*0.5
#t0 = time.time()
B = torch.einsum('ldr,adcb,icj->lairbj', (Asymm, T, Asymm)).contiguous().view(D**2*d, D**2*d)
C = torch.einsum('ldr,idj->lirj', (Asymm, Asymm)).contiguous().view(D**2, D**2)
#print ('einsum', time.time()- t0)
#t0 = time.time()
lnZ1, lnZ2= mpsrg(B, C)
#print ('mpsrg', time.time()- t0)
loss = -lnZ1 + lnZ2
print (' loss', loss.item(), lnZ1.item(), lnZ2.item())
#t0 = time.time()
loss.backward()
#print ('backward', time.time()- t0)
return loss
for epoch in range(Nepochs):
loss = optimizer.step(closure)
print ('epoch, free energy', epoch, loss.item())
return -loss
if __name__=='__main__':
import time
import argparse
parser = argparse.ArgumentParser(description='')
parser.add_argument("-D", type=int, default=2, help="D")
parser.add_argument("-Dcut", type=int, default=20, help="Dcut")
parser.add_argument("-beta", type=float, default=0.44, help="beta")
parser.add_argument("-Niter", type=int, default=32, help="Niter")
parser.add_argument("-Nepochs", type=int, default=100, help="Nepochs")
parser.add_argument("-float32", action='store_true', help="use float32")
parser.add_argument("-cuda", type=int, default=-1, help="use GPU")
args = parser.parse_args()
device = torch.device("cpu" if args.cuda<0 else "cuda:"+str(args.cuda))
dtype = torch.float32 if args.float32 else torch.float64
K = torch.tensor([args.beta], dtype=torch.float64, device=device)
c = torch.sqrt(torch.cosh(K))
s = torch.sqrt(torch.sinh(K))
M = torch.stack([torch.cat([c, s]), torch.cat([c, -s])])
T = torch.einsum('ai,aj,ak,al->ijkl', (M, M, M, M))
lnZ = vmps(T, 2, args.Dcut, args.Niter, args.Nepochs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment