Advanced Computing Platform for Theoretical Physics

qr.py 1.46 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch

class QR(torch.autograd.Function):
    """Reduced QR decomposition ``A = Q @ R`` with a hand-written backward pass.

    ``forward`` factorizes ``A`` of shape ``(..., M, N)`` with
    ``torch.linalg.qr(A, mode="reduced")``. With ``K = min(M, N)`` it returns
    ``Q`` of shape ``(..., M, K)`` and upper-triangular ``R`` of shape
    ``(..., K, N)``.

    ``backward`` handles two cases:

    * tall or square ``A`` (``M >= N``), where ``R`` is square;
    * wide ``A`` (``M < N``).
    """

    @staticmethod
    def forward(ctx, A):
        # torch.qr is deprecated; linalg.qr with mode="reduced" is its
        # replacement and matches torch.qr's default (some=True).
        Q, R = torch.linalg.qr(A, mode="reduced")
        ctx.save_for_backward(A, Q, R)
        return Q, R

    @staticmethod
    def backward(ctx, dq, dr):
        A, q, r = ctx.saved_tensors
        k, n = r.shape[-2:]
        if k == n:
            # Tall or square input: R is square, so the closed form applies.
            return _simple_qr_backward(q, r, dq, dr)
        # Wide input (M < N): split A = [X | Y] and R = [U | V], so that
        # X = Q @ U is a square QR and Y = Q @ V. Since V = Q^T @ Y, the
        # gradient flowing into V reaches Q as the extra term Y @ dV^T.
        Y = A[..., :, k:]
        dU = dr[..., :, :k]
        dV = dr[..., :, k:]
        U = r[..., :, :k]
        dX = _simple_qr_backward(q, U, dq + Y @ dV.transpose(-2, -1), dU)
        dY = q @ dV
        return torch.cat([dX, dY], dim=-1)

def _triangular_solve(x, r):
    """Return ``x @ inverse(r).transpose(-2, -1)`` for an upper-triangular ``r``.

    Uses a triangular solve instead of forming the inverse. Solving
    ``r @ y = x^T`` gives ``y = inverse(r) @ x^T``, and ``y^T`` is the
    requested product. This has the same semantics as the former
    ``torch.trtrs(x.t(), r, upper=True, transpose=False)[0].t()`` call, which
    no longer exists in PyTorch. Leading batch dimensions are supported.
    """
    y = torch.linalg.solve_triangular(r, x.transpose(-2, -1), upper=True)
    return y.transpose(-2, -1)


def _simple_qr_backward(q, r, dq, dr):
    """Backpropagate through a QR factorization ``A = q @ r`` with square ``r``.

    This is the closed-form gradient for the reduced QR of a tall or square
    matrix. The error message below mirrors TensorFlow's ``QrGrad``; the
    formula appears to be ported from there.

    Args:
        q: Factor with orthonormal columns, shape ``(..., M, N)``.
        r: Upper-triangular factor, shape ``(..., N, N)``. It must be
            nonsingular because it is inverted via a triangular solve.
        dq: Gradient of the loss with respect to ``q`` (same shape as ``q``).
        dr: Gradient of the loss with respect to ``r`` (same shape as ``r``).

    Returns:
        Gradient of the loss with respect to ``A``, shape ``(..., M, N)``.

    Raises:
        NotImplementedError: If ``r`` is not square.
    """
    if r.shape[-2] != r.shape[-1]:
        raise NotImplementedError("QrGrad not implemented when ncols > nrows "
                                  "or full_matrices is true and ncols != nrows.")

    qdq = q.transpose(-2, -1) @ dq
    # Antisymmetric parts of q^T dq and r dr^T; their sum has a zero diagonal,
    # and only its lower triangle enters the gradient.
    qdq_ = qdq - qdq.transpose(-2, -1)
    rdr = r @ dr.transpose(-2, -1)
    rdr_ = rdr - rdr.transpose(-2, -1)
    tril = torch.tril(qdq_ + rdr_)

    grad_a = q @ (dr + _triangular_solve(tril, r))
    # dq - q (q^T dq) removes from dq its component inside span(q), given
    # that q has orthonormal columns.
    grad_b = _triangular_solve(dq - q @ qdq, r)
    return grad_a + grad_b

def test_qr():
    """Gradient-check ``QR`` on wide, tall, and square inputs.

    ``torch.autograd.gradcheck`` compares the analytic backward against
    finite differences, and its default tolerances are designed for double
    precision. The inputs are therefore float64 and those defaults are kept,
    rather than the loose ``eps=1e-4, atol=1e-2`` a float32 check needs
    (which can hide real gradient errors). ``gradcheck`` raises on mismatch.
    """
    torch.manual_seed(2)
    # (4, 6) exercises the wide branch (M < N); (6, 4) and (5, 5) exercise
    # the square-R branch.
    for M, N in [(4, 6), (6, 4), (5, 5)]:
        A = torch.randn(M, N, dtype=torch.float64, requires_grad=True)
        Q, R = QR.apply(A)
        torch.testing.assert_close((Q @ R).detach(), A.detach())
        torch.autograd.gradcheck(QR.apply, (A,))
    print("Test Pass!")


if __name__ == "__main__":
    test_qr()