Advanced Computing Platform for Theoretical Physics

Commit 685e853b authored by Lei Wang's avatar Lei Wang
Browse files

added a test on updates

parent 28563c09
import torch
torch.manual_seed(42)
dtype = torch.float64
device = 'cuda:0'
N = 100
chi = 20
A = torch.randn(N, N, dtype=dtype, device=device)
U, S, V = torch.svd(A)
P = U[:, :chi]
projected_trace = torch.trace(A@A.t() @ P @P.t())
target = ((S[:chi])**2).sum()
print (projected_trace.item(), target.item())
w = torch.randn(N, chi, dtype=dtype, device=device)
for step in range(20):
E = w.t()@A@A.t()
U, S, V = torch.svd(E)
w = V@U.t()
loss = torch.trace(A@A.t() @ w @ w.t())
print (step, ((loss-target)/target).item(), ((projected_trace-target)/target).item())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment