diff --git a/dimer_covering.py b/dimer_covering.py
index 8668e2787d8154ef1bdd7149d6f964f7f293c6bb..7217566e37ee499eb962f49658527ffdad796019 100644
--- a/dimer_covering.py
+++ b/dimer_covering.py
@@ -90,7 +90,7 @@ if __name__=='__main__':
     ####################################################################3
     def fun(x, info):
-        A = torch.as_tensor(x).requires_grad_()
+        A = torch.as_tensor(x, device=device).requires_grad_()
         As = symmetrize(A)
 
         T1 = torch.einsum('xa,xby,yc' , (As,T,As)).view(D,D,D,D, d,d,d,d, D,D,D,D).permute(0,4,8, 1,5,9, 2,6,10, 3,7,11).contiguous().view(D**2*d, D**2*d, D**2*d, D**2*d)
@@ -106,18 +106,18 @@ if __name__=='__main__':
         info['A'] = A
 
         print (info['feval'], loss.item(), A.grad.norm().item())
-        return loss.item(), A.grad.detach().numpy().ravel()
+        return loss.item(), A.grad.detach().cpu().numpy().ravel()
 
     # see this https://code.itp.ac.cn/wanglei/dpeps/commit/2e47d663bb2c8e155967c4b644edf63d9e9b22e9
     def hessp(x, p, info):
         A = info['A']
-        if ((x == A.detach().numpy().ravel()).all()):
+        if ((x == A.detach().cpu().numpy().ravel()).all()):
             grad = A.grad.view(-1)
         else:
             print ('recalculate forward and grad')
-            A = torch.as_tensor(A).requires_grad_()
+            A = torch.as_tensor(A, device=device).requires_grad_()
             As = symmetrize(A)
 
@@ -132,11 +132,11 @@ if __name__=='__main__':
             grad = A.grad.view(-1)
 
         #dot it with the given vector
-        loss = torch.dot(grad, torch.as_tensor(p))
+        loss = torch.dot(grad, torch.as_tensor(p, device=device))
         hvp = torch.autograd.grad(loss, A, retain_graph=True)[0].view(-1)
-        return hvp.numpy().ravel()
+        return hvp.cpu().numpy().ravel()
 
     import scipy.optimize
-    x0 = A.detach().numpy().ravel()
+    x0 = A.detach().cpu().numpy().ravel()
     x = scipy.optimize.minimize(fun, x0, args=({'feval':0},), jac=True, hessp=hessp, method='Newton-CG', options={'xtol':1E-8})
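
The changes above all follow one pattern: build every tensor on `device` and convert back with `.cpu().numpy()` only at the SciPy boundary, so the same script runs unchanged on CPU or GPU. Below is a minimal, self-contained sketch of that pattern (not the repository's code; the quadratic objective with `H` and `b` is a hypothetical stand-in for the tensor-network loss), including the Hessian-vector product obtained from a second `torch.autograd.grad` call, which Newton-CG consumes through `hessp`.

import numpy as np
import scipy.optimize
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hypothetical stand-in objective: f(x) = 0.5 x^T H x - b^T x with H symmetric positive definite
n = 10
H = torch.randn(n, n, dtype=torch.float64, device=device)
H = H @ H.t() + n * torch.eye(n, dtype=torch.float64, device=device)
b = torch.randn(n, dtype=torch.float64, device=device)

def fun(x):
    # build the graph on `device`; SciPy hands us a CPU NumPy array
    A = torch.as_tensor(x, device=device).requires_grad_()
    loss = 0.5 * A @ H @ A - b @ A
    loss.backward()
    # convert to NumPy only when returning to SciPy
    return loss.item(), A.grad.detach().cpu().numpy().ravel()

def hessp(x, p):
    # Hessian-vector product: differentiate (grad . p) a second time
    A = torch.as_tensor(x, device=device).requires_grad_()
    loss = 0.5 * A @ H @ A - b @ A
    grad = torch.autograd.grad(loss, A, create_graph=True)[0]
    gp = torch.dot(grad, torch.as_tensor(p, device=device))
    hvp = torch.autograd.grad(gp, A)[0]
    return hvp.detach().cpu().numpy().ravel()

x0 = np.zeros(n)
res = scipy.optimize.minimize(fun, x0, jac=True, hessp=hessp,
                              method='Newton-CG', options={'xtol': 1e-8})
print(res.fun, np.linalg.norm(res.jac))

The sketch recomputes the forward pass inside hessp for clarity; the patched script instead caches the tensor in info['A'] and only recomputes when SciPy calls hessp at a point different from the last fun evaluation, which saves one forward pass per Newton-CG iteration.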