import torch
from adlib import SVD

svd = SVD.apply


def _factorize(mat, D, D_new):
    """Split a (D**2, D**2) matrix into two rank-3 halves by truncated SVD.

    Keeps the leading ``D_new`` singular triplets and distributes ``sqrt(S)``
    evenly onto both halves.

    Args:
        mat: square matrix of side ``D**2``.
        D: bond dimension of the incoming legs.
        D_new: number of singular values kept (new bond dimension).

    Returns:
        ``(left, right, err)``: ``left`` and ``right`` have shape
        ``(D, D, D_new)`` and are built from the first and third outputs of
        ``svd``. ``err`` is the fraction of the singular-value sum that was
        discarded. NOTE(review): this assumes ``svd`` follows the ``torch.svd``
        convention ``mat = U @ diag(S) @ V.T`` (V not transposed). Confirm
        against adlib.
    """
    U, S, V = svd(mat)
    sqrt_s = torch.sqrt(S[:D_new])
    left = (U[:, :D_new] * sqrt_s).view(D, D, D_new)
    right = (V[:, :D_new] * sqrt_s).view(D, D, D_new)
    err = S[D_new:].sum() / S.sum()
    return left, right, err


def levin_nave_trg(T, D, Dcut, no_iter):
    """Run Levin-Nave TRG coarse-graining and return the normalized log of Z.

    Args:
        T: rank-4 local tensor, assumed shape ``(D, D, D, D)``. Legs are
            presumably ordered (right, up, left, down), judging by the
            ``'ruld'`` output label of the contraction below.
        D: initial bond dimension of ``T``.
        Dcut: maximum bond dimension kept after each SVD truncation.
        no_iter: number of coarse-graining steps.

    Returns:
        ``(lnZ / 2**no_iter, truncation_error)``.
        ``truncation_error`` is the sum, over all steps and both SVDs per
        step, of the discarded fraction of singular values. It is the float
        ``0.0`` when ``no_iter == 0``; otherwise it is a tensor.
    """
    lnZ = 0.0
    truncation_error = 0.0
    for n in range(no_iter):
        # Rescale to keep entries O(1). The factor pulled out is weighted by
        # 2**(no_iter - n), the number of tensors presumably left at step n.
        # Use the largest *magnitude*: the SVD sign/gauge freedom can make
        # entries negative, and T.max() could then be tiny or non-positive
        # (log -> NaN). Any positive scale gives the same final lnZ.
        maxval = T.abs().max()
        T = T / maxval
        lnZ += 2**(no_iter - n) * torch.log(maxval)

        D_new = min(D**2, Dcut)
        # Two complementary ways of grouping the four legs into a matrix.
        Ma = T.permute(2, 1, 0, 3).contiguous().view(D**2, D**2)
        Mb = T.permute(3, 2, 1, 0).contiguous().view(D**2, D**2)
        S1, S3, err_a = _factorize(Ma, D, D_new)
        S2, S4, err_b = _factorize(Mb, D, D_new)
        truncation_error += err_a + err_b

        # Same as torch.einsum('war,abu,bgl,gwd->ruld', S1, S2, S3, S4),
        # done as pairwise tensordots.
        T = torch.tensordot(
            torch.tensordot(S1, S2, dims=([1], [0])),
            torch.tensordot(S3, S4, dims=([1], [0])),
            dims=([2, 0], [0, 2]))
        D = D_new

    # Close the last tensor on itself: sum over x, y of T[x, y, y, x].
    trace = torch.einsum('xyyx->', T)
    lnZ += torch.log(trace)
    return lnZ / 2**no_iter, truncation_error


if __name__ == '__main__':
    K = torch.tensor([0.44])
    Dcut = 20
    n = 20

    # M @ M.T == [[e^K, e^-K], [e^-K, e^K]], so contracting four copies of M
    # on a shared index gives a rank-4 site tensor whose bond weights are
    # those Boltzmann factors.
    c = torch.sqrt(torch.cosh(K) / 2.)
    s = torch.sqrt(torch.sinh(K) / 2.)
    M = torch.stack([torch.cat([c + s, c - s]), torch.cat([c - s, c + s])])
    T = torch.einsum('ai,aj,ak,al->ijkl', (M, M, M, M))

    lnZ, error = levin_nave_trg(T, 2, Dcut, n)
    print(lnZ.item(), float(error))