We have $$ \begin{align*} (x - a)^T A^{-1} (x - a) + (x - b)^T B^{-1} (x - b) &= x^T (A^{-1} + B^{-1}) x - 2x^T(A^{-1}a + B^{-1}b) + (a^T A^{-1} a + b^T B^{-1} b) \\ &= x^T C^{-1} x - 2 x^T C^{-1} c + c^T C^{-1} c + \bbox[yellow]d \\ &= (x - c)^T C^{-1} (x - c) + \bbox[yellow]d, \end{align*} $$ where $$ C = (A^{-1} + B^{-1})^{-1}, $$ $$ c = C(A^{-1}a + B^{-1}b), $$ $$ \bbox[yellow]d = a^T A^{-1} a + b^T B^{-1} b - c^T C^{-1} c. $$
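The completed square can be sanity-checked numerically at a random point $x$; a minimal PyTorch sketch (the SPD matrices $A$, $B$ and the vectors are random test inputs, purely illustrative):

```python
import torch

torch.manual_seed(0)
n = 4
# Random symmetric positive-definite test matrices (illustrative only).
A = torch.randn(n, n); A = A @ A.T + n * torch.eye(n)
B = torch.randn(n, n); B = B @ B.T + n * torch.eye(n)
a, b, x = torch.randn(n), torch.randn(n), torch.randn(n)

Ainv, Binv = torch.inverse(A), torch.inverse(B)
Cinv = Ainv + Binv
C = torch.inverse(Cinv)
c = C @ (Ainv @ a + Binv @ b)
d = a @ Ainv @ a + b @ Binv @ b - c @ Cinv @ c

# Both sides of the completed square, evaluated at x.
lhs = (x - a) @ Ainv @ (x - a) + (x - b) @ Binv @ (x - b)
rhs = (x - c) @ Cinv @ (x - c) + d
```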
On the other hand, we have: $$ \begin{align*} c^T C^{-1} c &= c^T (A^{-1}a + B^{-1}b) \\ &= (a^T A^{-1} + b^T B^{-1})(A^{-1} + B^{-1})^{-1} (A^{-1}a + B^{-1}b) \\ &= a^T E a + b^T F b + 2a^T G b, \end{align*} $$ where $$ E = A^{-1}(A^{-1} + B^{-1})^{-1}A^{-1} = (A + AB^{-1}A)^{-1} \overset{Woodbury}{=} A^{-1} - (A + B)^{-1}, $$ $$ F = B^{-1}(A^{-1} + B^{-1})^{-1}B^{-1} = B^{-1} - (A + B)^{-1} \text{ (similarly)}, $$ $$ G = A^{-1}(A^{-1} + B^{-1})^{-1}B^{-1} = (A + B)^{-1}. $$
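The Woodbury identities for $E$, $F$, $G$ are easy to verify numerically; a quick sketch with random SPD test matrices (illustrative inputs, not anything from the derivation):

```python
import torch

torch.manual_seed(0)
n = 4
# Random symmetric positive-definite test matrices.
A = torch.randn(n, n); A = A @ A.T + n * torch.eye(n)
B = torch.randn(n, n); B = B @ B.T + n * torch.eye(n)

Ainv, Binv = torch.inverse(A), torch.inverse(B)
M = torch.inverse(Ainv + Binv)   # (A^{-1} + B^{-1})^{-1}
S = torch.inverse(A + B)         # (A + B)^{-1}

# E, F, G from their definitions vs. the Woodbury forms.
err_E = (Ainv @ M @ Ainv - (Ainv - S)).abs().max()
err_F = (Binv @ M @ Binv - (Binv - S)).abs().max()
err_G = (Ainv @ M @ Binv - S).abs().max()
```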
Hence from $$ c^T C^{-1} c = a^T A^{-1} a + b^T B^{-1} b - (a - b)^T (A + B)^{-1} (a - b), $$ we obtain alternative expressions for $\bbox[yellow]d$: $$ \bbox[yellow]d = \bbox[orange]{(a - b)^T (A + B)^{-1} (a - b)} = (a - b)^T A^{-1}(A^{-1} + B^{-1})^{-1}B^{-1} (a - b) = \bbox[yellow]{(a - b)^T A^{-1}CB^{-1} (a - b)}. $$
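All three expressions for $\bbox[yellow]d$ should agree; a minimal numerical check with random SPD test inputs (illustrative only):

```python
import torch

torch.manual_seed(0)
n = 4
# Random symmetric positive-definite test matrices.
A = torch.randn(n, n); A = A @ A.T + n * torch.eye(n)
B = torch.randn(n, n); B = B @ B.T + n * torch.eye(n)
a, b = torch.randn(n), torch.randn(n)

Ainv, Binv = torch.inverse(A), torch.inverse(B)
C = torch.inverse(Ainv + Binv)
c = C @ (Ainv @ a + Binv @ b)

# d from its definition, and the two boxed alternatives.
d_def = a @ Ainv @ a + b @ Binv @ b - c @ (Ainv + Binv) @ c
d_orange = (a - b) @ torch.inverse(A + B) @ (a - b)
d_yellow = (a - b) @ Ainv @ C @ Binv @ (a - b)
```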
Define $L_a = Cholesky(A)$, $L_b = Cholesky(B)$, $L_c = Cholesky(C)$; then $$ \begin{align*} C &= (L_a^{-T} L_a^{-1} + L_b^{-T} L_b^{-1})^{-1} \\ &= L_a(I + L_a^{T}L_b^{-T} L_b^{-1}L_a)^{-1}L_a^T. \end{align*} $$ So it is enough to solve $L = L_b \backslash L_a$, compute a lower triangular $L_d$ satisfying $L_dL_d^T = (I + L^TL)^{-1}$ (via the flip trick below, this takes one Cholesky decomposition of $flip(I + L^TL)$, which is numerically well-behaved because the eigenvalues of $I + L^TL$ are at least $1$), and finally define $L_c = L_aL_d$, a lower triangular matrix with $L_cL_c^T = C$.
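This factorization route can be checked end to end; a sketch with random SPD test matrices, where the two `torch.flip` lines mirror the `precision_to_scale_tril` helper defined later in this note:

```python
import torch

torch.manual_seed(0)
n = 4
# Random symmetric positive-definite test matrices.
A = torch.randn(n, n); A = A @ A.T + n * torch.eye(n)
B = torch.randn(n, n); B = B @ B.T + n * torch.eye(n)

La = torch.linalg.cholesky(A)
Lb = torch.linalg.cholesky(B)
L = torch.linalg.solve_triangular(Lb, La, upper=False)  # L = Lb \ La

M = torch.eye(n) + L.T @ L   # I + L^T L, eigenvalues >= 1
# Lower-triangular Ld with Ld Ld^T = M^{-1} (flip trick).
Lf = torch.linalg.cholesky(torch.flip(M, (-1, -2)))
Ld = torch.inverse(torch.flip(Lf, (-1, -2)).T).tril()

Lc = La @ Ld
C = torch.inverse(torch.inverse(A) + torch.inverse(B))
err = (Lc @ Lc.T - C).abs().max()
```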
Now, we compute $c$: $$ CA^{-1}a = L_c L_d^T L_a^T L_a^{-T} L_a^{-1} a = L_c L_d^T (L_a \backslash a), $$ $$ CB^{-1}b = L_c L_d^T L_a^T L_b^{-T} L_b^{-1} b = L_c L_d^T L^T (L_b \backslash b), $$ $$ c = L_c L_d^T \left[ L_a\backslash a + L^T (L_b\backslash b) \right]. $$
$L_a\backslash a$ and $L_b\backslash b$ can be used to compute the corresponding Mahalanobis terms of $\bbox[yellow]d$. We can also compute $$ L_c \backslash c = L_d^T \left[ L_a\backslash a + L^T (L_b\backslash b) \right]. $$
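The whole pipeline for $c$ built from triangular solves can be verified against the direct formula $c = C(A^{-1}a + B^{-1}b)$; a sketch with random SPD test inputs (the flip-trick lines for $L_d$ mirror the helper defined later):

```python
import torch

torch.manual_seed(0)
n = 4
# Random symmetric positive-definite test matrices.
A = torch.randn(n, n); A = A @ A.T + n * torch.eye(n)
B = torch.randn(n, n); B = B @ B.T + n * torch.eye(n)
a, b = torch.randn(n, 1), torch.randn(n, 1)

La, Lb = torch.linalg.cholesky(A), torch.linalg.cholesky(B)
L = torch.linalg.solve_triangular(Lb, La, upper=False)  # L = Lb \ La
M = torch.eye(n) + L.T @ L
Lf = torch.linalg.cholesky(torch.flip(M, (-1, -2)))
Ld = torch.inverse(torch.flip(Lf, (-1, -2)).T).tril()   # Ld Ld^T = M^{-1}
Lc = La @ Ld

wa = torch.linalg.solve_triangular(La, a, upper=False)  # La \ a
wb = torch.linalg.solve_triangular(Lb, b, upper=False)  # Lb \ b
c = Lc @ (Ld.T @ (wa + L.T @ wb))

# Reference: c = C (A^{-1} a + B^{-1} b).
Ainv, Binv = torch.inverse(A), torch.inverse(B)
c_ref = torch.inverse(Ainv + Binv) @ (Ainv @ a + Binv @ b)
err = (c - c_ref).abs().max()
```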
To make it compatible with the actual Gaussian log density, we need to divide $d$ by $2$. We also need to add the following adjusted normalization constant (scaled by a factor of $2$) to $d$: $$ \begin{align*} \delta_{normalization} &= n\log(2\pi) + \log |A| + n\log(2\pi) + \log |B| - n\log(2\pi) - \log |C| \\ &= n\log(2\pi) + \log |AC^{-1}B| = \bbox[orange]{n\log(2\pi) + \log |A + B|} \\ &= n\log(2\pi) + 2 \sum \left[ \log diag(L_a) + \log diag(L_b) - \log diag(L_c) \right]. \end{align*} $$
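The log-determinant identity $\log|A + B| = 2\sum[\log diag(L_a) + \log diag(L_b) - \log diag(L_c)]$ can be confirmed numerically; a sketch with random SPD test matrices (the $L_d$ lines again use the flip trick described later):

```python
import torch

torch.manual_seed(0)
n = 4
# Random symmetric positive-definite test matrices.
A = torch.randn(n, n); A = A @ A.T + n * torch.eye(n)
B = torch.randn(n, n); B = B @ B.T + n * torch.eye(n)

La, Lb = torch.linalg.cholesky(A), torch.linalg.cholesky(B)
L = torch.linalg.solve_triangular(Lb, La, upper=False)
M = torch.eye(n) + L.T @ L
Lf = torch.linalg.cholesky(torch.flip(M, (-1, -2)))
Ld = torch.inverse(torch.flip(Lf, (-1, -2)).T).tril()
Lc = La @ Ld   # Lc Lc^T = C

# log|A + B| directly, vs. from the Cholesky diagonals.
lhs = torch.logdet(A + B)
rhs = 2 * (La.diagonal().log() + Lb.diagonal().log()
           - Lc.diagonal().log()).sum()
err = (lhs - rhs).abs()
```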
From the $\bbox[orange]{orange}$ boxes, it follows that, after dividing by $2$ and adding the normalization constant, $$ d = -\log \mathcal{N}(a - b \mid 0, A + B). $$
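Equivalently, the product of two Gaussian densities in $x$ factors as a Gaussian in $x$ times $\mathcal{N}(a - b \mid 0, A + B)$, which can be checked with `torch.distributions`; a sketch with random SPD test inputs:

```python
import torch
from torch.distributions import MultivariateNormal

torch.manual_seed(0)
n = 4
# Random symmetric positive-definite test matrices.
A = torch.randn(n, n); A = A @ A.T + n * torch.eye(n)
B = torch.randn(n, n); B = B @ B.T + n * torch.eye(n)
a, b, x = torch.randn(n), torch.randn(n), torch.randn(n)

Ainv, Binv = torch.inverse(A), torch.inverse(B)
C = torch.inverse(Ainv + Binv)
c = C @ (Ainv @ a + Binv @ b)

# log N(x|a,A) + log N(x|b,B) = log N(x|c,C) + log N(a-b|0,A+B)
lhs = MultivariateNormal(a, A).log_prob(x) + MultivariateNormal(b, B).log_prob(x)
rhs = (MultivariateNormal(c, C).log_prob(x)
       + MultivariateNormal(torch.zeros(n), A + B).log_prob(a - b))
err = (lhs - rhs).abs()
```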
Define $flip(X) = X[::-1, ::-1]$; then $flip(AB) = flip(A)flip(B)$. Indeed, using Python-style indexing, $$flip(AB)[i, j] = (AB)[-1-i, -1-j] = A[-1-i,:] \cdot B[:,-1-j] = flip(A)[i,:] \cdot flip(B)[:,j].$$
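A two-line check of this multiplicativity (the `flip` helper here is just the operator above, written with `torch.flip`):

```python
import torch

def flip(X):
    # X[::-1, ::-1] over the last two dimensions
    return torch.flip(X, (-1, -2))

torch.manual_seed(0)
A = torch.randn(4, 4)
B = torch.randn(4, 4)
err = (flip(A @ B) - flip(A) @ flip(B)).abs().max()
```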
Applying this operator: from $$ C = LL^T, $$ we get $$ P = C^{-1} = L^{-T}L^{-1}, $$ hence $$ flip(P) = flip(L^{-T})flip(L^{-1}) = flip(L^{-T})\,flip(L^{-T})^T. $$ Since flipping turns an upper triangular matrix into a lower triangular one (and commutes with transposition), $flip(L^{-T})$ is lower triangular with positive diagonal, so $$ flip(L^{-T}) = Cholesky(flip(P)) $$ and therefore $$ L = flip(Cholesky(flip(P)))^{-T}. $$
```python
import torch

def precision_to_scale_tril(P):
    # Cholesky of the flipped precision gives flip(L^{-T}).
    Lf = torch.linalg.cholesky(torch.flip(P, (-1, -2)))
    # Un-flip, transpose and invert to recover L.
    L = torch.inverse(torch.transpose(torch.flip(Lf, (-1, -2)), -1, -2))
    # torch.inverse of a triangular matrix is not exactly triangular
    # due to floating-point error, so clean it up with tril().
    return L.tril()

A = torch.randn(5, 5)
C = A.matmul(A.t())             # a random covariance matrix
P = torch.inverse(C)            # its precision matrix
L = precision_to_scale_tril(P)

print("test for agreement on covariance:")
print(C)
print(L.matmul(L.t()))
print("===============")
print("test for agreement on scale_tril:")
print(L)
print(torch.linalg.cholesky(C))
```
```
test for agreement on covariance:
tensor([[10.3777,  0.4256, -1.2044, -0.3092, -5.3482],
        [ 0.4256,  2.3792, -0.6797, -2.5855, -1.5214],
        [-1.2044, -0.6797,  6.9600,  0.7112, -0.6421],
        [-0.3092, -2.5855,  0.7112,  2.9499,  1.5245],
        [-5.3482, -1.5214, -0.6421,  1.5245,  4.0248]])
tensor([[10.3777,  0.4256, -1.2044, -0.3092, -5.3482],
        [ 0.4256,  2.3792, -0.6797, -2.5856, -1.5214],
        [-1.2044, -0.6797,  6.9600,  0.7112, -0.6421],
        [-0.3092, -2.5856,  0.7112,  2.9500,  1.5246],
        [-5.3482, -1.5214, -0.6421,  1.5246,  4.0248]])
===============
test for agreement on scale_tril:
tensor([[ 3.2214,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.1321,  1.5368,  0.0000,  0.0000,  0.0000],
        [-0.3739, -0.4102,  2.5791,  0.0000,  0.0000],
        [-0.0960, -1.6742, -0.0044,  0.3713,  0.0000],
        [-1.6602, -0.8473, -0.6244, -0.1508,  0.3716]])
tensor([[ 3.2214,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.1321,  1.5368,  0.0000,  0.0000,  0.0000],
        [-0.3739, -0.4102,  2.5791,  0.0000,  0.0000],
        [-0.0960, -1.6742, -0.0044,  0.3713,  0.0000],
        [-1.6602, -0.8473, -0.6244, -0.1508,  0.3716]])
```