
Commit e3f4b90c authored by Lei Wang

runnable on gpu

parent 5e342755
@@ -90,7 +90,7 @@ if __name__=='__main__':
 ####################################################################3
     def fun(x, info):
-        A = torch.as_tensor(x).requires_grad_()
+        A = torch.as_tensor(x, device=device).requires_grad_()
         As = symmetrize(A)
         T1 = torch.einsum('xa,xby,yc' , (As,T,As)).view(D,D,D,D, d,d,d,d, D,D,D,D).permute(0,4,8, 1,5,9, 2,6,10, 3,7,11).contiguous().view(D**2*d, D**2*d, D**2*d, D**2*d)
@@ -106,18 +106,18 @@ if __name__=='__main__':
         info['A'] = A
         print (info['feval'], loss.item(), A.grad.norm().item())
-        return loss.item(), A.grad.detach().numpy().ravel()
+        return loss.item(), A.grad.detach().cpu().numpy().ravel()
     # see this https://code.itp.ac.cn/wanglei/dpeps/commit/2e47d663bb2c8e155967c4b644edf63d9e9b22e9
     def hessp(x, p, info):
         A = info['A']
-        if ((x == A.detach().numpy().ravel()).all()):
+        if ((x == A.detach().cpu().numpy().ravel()).all()):
             grad = A.grad.view(-1)
         else:
             print ('recalculate forward and grad')
-            A = torch.as_tensor(A).requires_grad_()
+            A = torch.as_tensor(A, device=device).requires_grad_()
             As = symmetrize(A)
@@ -132,11 +132,11 @@ if __name__=='__main__':
             grad = A.grad.view(-1)
         #dot it with the given vector
-        loss = torch.dot(grad, torch.as_tensor(p))
+        loss = torch.dot(grad, torch.as_tensor(p, device=device))
         hvp = torch.autograd.grad(loss, A, retain_graph=True)[0].view(-1)
-        return hvp.numpy().ravel()
+        return hvp.cpu().numpy().ravel()
     import scipy.optimize
-    x0 = A.detach().numpy().ravel()
+    x0 = A.detach().cpu().numpy().ravel()
     x = scipy.optimize.minimize(fun, x0, args=({'feval':0},), jac=True, hessp=hessp, method='Newton-CG', options={'xtol':1E-8})
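
For readers outside the repo, the sketch below condenses the pattern this commit applies: run the forward and backward passes on a torch device (GPU when available) and cross over to CPU NumPy only at the scipy boundary, with hessp supplying Hessian-vector products to Newton-CG via a second backward pass. This is a minimal sketch under stated assumptions, not the dpeps script itself: the quadratic toy loss, the matrix M, and the explicit device line are placeholders for the tensor-network contraction and the setup code outside this diff (symmetrize, T, D, and d are defined elsewhere in the file).

    # Minimal sketch of the GPU/CPU round-trip pattern (assumed setup).
    import numpy as np
    import torch
    import scipy.optimize

    # the diff assumes a `device` defined earlier in the script, e.g.:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    n = 10
    # symmetric positive-definite toy Hessian standing in for the real loss
    M = torch.randn(n, n, dtype=torch.float64, device=device)
    M = M @ M.t() + n * torch.eye(n, dtype=torch.float64, device=device)

    def fun(x, info):
        # scipy hands over a CPU float64 array; move it onto the device
        A = torch.as_tensor(x, device=device).requires_grad_()
        loss = 0.5 * A @ (M @ A)   # stand-in for the contraction loss
        # create_graph=True keeps the graph alive so hessp can backprop again
        A.grad, = torch.autograd.grad(loss, A, create_graph=True)
        info['A'] = A              # cache the graph for hessp, as in the commit
        info['feval'] += 1
        # cross back to CPU NumPy for scipy
        return loss.item(), A.grad.detach().cpu().numpy().ravel()

    def hessp(x, p, info):
        # Newton-CG calls this at the iterate where fun/jac were just
        # evaluated, so the cached graph can be reused (the full script
        # re-runs the forward pass when x differs from the cached point)
        A = info['A']
        grad = A.grad.view(-1)
        # Hessian-vector product: differentiate grad . p once more
        gp = torch.dot(grad, torch.as_tensor(p, device=device))
        hvp = torch.autograd.grad(gp, A, retain_graph=True)[0].view(-1)
        return hvp.cpu().numpy().ravel()

    x0 = np.random.randn(n)
    res = scipy.optimize.minimize(fun, x0, args=({'feval': 0},), jac=True,
                                  hessp=hessp, method='Newton-CG',
                                  options={'xtol': 1e-8})
    print(res.fun)  # ~0 at the minimum x = 0

The only device-specific lines are the two boundaries: torch.as_tensor(..., device=device) on the way in, and .cpu().numpy() on the way out; everything scipy touches stays on the CPU, which is exactly what the .cpu() insertions in this commit fix.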