import torch

torch.set_num_threads(4)


def vmps(T, d, D, no_iter, Nepochs):
    """Variationally optimize an MPS for the transfer tensor T with LBFGS.

    Args:
        T: rank-4 transfer tensor of shape (d, d, d, d); its device and
           dtype determine where/how the optimization runs.
        d: physical bond dimension.
        D: MPS virtual (bond) dimension.
        no_iter: number of matrix-squaring (doubling) steps used to
            estimate the leading eigenvalues of the transfer matrices.
        Nepochs: number of LBFGS epochs (each epoch is one optimizer.step).

    Returns:
        -loss from the final epoch, i.e. the free-energy estimate
        lnZ1 - lnZ2 (a 0-dim tensor).
    """
    # Derive device/dtype from T instead of relying on module-level globals.
    # This also makes the -float32 flag effective (dtype was previously
    # hard-coded to float64 here).
    device = T.device
    dtype = T.dtype
    A = torch.nn.Parameter(
        0.01 * torch.randn(D, d, D, dtype=dtype, device=device))

    def mpsrg(B, C):
        # Power-method-by-doubling: repeatedly square the transfer matrices,
        # normalizing at each step and accumulating log(norm)/2**i so the
        # sums converge toward the log of the leading eigenvalue per site.
        # NOTE(review): the final trace correction log(tr B)/2**no_iter is
        # deliberately omitted (it was commented out in the original) —
        # the series is truncated at no_iter doublings.
        lnZ1 = 0.0
        lnZ2 = 0.0
        for i in range(no_iter):
            s = B.norm(1)
            lnZ1 = lnZ1 + torch.log(s) / 2**i
            B = B / s
            B = torch.mm(B, B)

            s = C.norm(1)
            lnZ2 = lnZ2 + torch.log(s) / 2**i
            C = C / s
            C = torch.mm(C, C)
        return lnZ1, lnZ2

    optimizer = torch.optim.LBFGS([A], max_iter=20)

    def closure():
        optimizer.zero_grad()
        # Symmetrize A under left-right reflection of the virtual bonds.
        Asymm = (A + A.permute(2, 1, 0)) * 0.5
        # B: transfer matrix with one row of T sandwiched between two MPS
        # tensors; C: MPS-only overlap transfer matrix (the normalization).
        B = torch.einsum('ldr,adcb,icj->lairbj', Asymm, T, Asymm) \
            .contiguous().view(D**2 * d, D**2 * d)
        C = torch.einsum('ldr,idj->lirj', Asymm, Asymm) \
            .contiguous().view(D**2, D**2)
        lnZ1, lnZ2 = mpsrg(B, C)
        # Variational free energy: maximize lnZ1 - lnZ2.
        loss = -lnZ1 + lnZ2
        print(' loss', loss.item(), lnZ1.item(), lnZ2.item())
        loss.backward()
        return loss

    for epoch in range(Nepochs):
        loss = optimizer.step(closure)
        print('epoch, free energy', epoch, loss.item())
    return -loss


if __name__ == '__main__':
    import time
    import argparse

    parser = argparse.ArgumentParser(description='')
    parser.add_argument("-D", type=int, default=2, help="D")
    parser.add_argument("-Dcut", type=int, default=20, help="Dcut")
    parser.add_argument("-beta", type=float, default=0.44, help="beta")
    parser.add_argument("-Niter", type=int, default=32, help="Niter")
    parser.add_argument("-Nepochs", type=int, default=100, help="Nepochs")
    parser.add_argument("-float32", action='store_true', help="use float32")
    parser.add_argument("-cuda", type=int, default=-1, help="use GPU")
    args = parser.parse_args()

    device = torch.device(
        "cpu" if args.cuda < 0 else "cuda:" + str(args.cuda))
    dtype = torch.float32 if args.float32 else torch.float64

    # Build the 2D Ising transfer tensor T_{ijkl} = sum_a M_ai M_aj M_ak M_al
    # from the square-root Boltzmann factors. Bug fix: `dtype` was computed
    # but never used — everything was hard-coded float64, so -float32 was a
    # no-op; K now honors the flag and vmps inherits dtype/device from T.
    K = torch.tensor([args.beta], dtype=dtype, device=device)
    c = torch.sqrt(torch.cosh(K))
    s = torch.sqrt(torch.sinh(K))
    M = torch.stack([torch.cat([c, s]), torch.cat([c, -s])])
    T = torch.einsum('ai,aj,ak,al->ijkl', M, M, M, M)

    lnZ = vmps(T, 2, args.Dcut, args.Niter, args.Nepochs)