Neural KL divergence estimation (Donsker-Varadhan representation) using torch
Source: R/kld-estimation-neural.R
kld_est_neural.Rd
Estimation of KL divergence between continuous distributions based on the
Donsker-Varadhan representation
$$D_{KL}(P||Q) = \sup_{f} E_P[f(X)] - \log\left(E_Q[e^{f(X)}]\right)$$
using Monte Carlo averages to approximate the expectations, and optimizing
over a class of neural networks. The torch
package is required to use this
function.
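For a fixed candidate function, the right-hand side is approximated by plugging in sample means. A minimal illustrative sketch in base R (the test function f and the helper dv_objective are hypothetical, not part of the package):

# Monte Carlo approximation of the Donsker-Varadhan bound for a fixed f:
# mean of f over samples from P minus log of the mean of exp(f) over samples from Q.
dv_objective <- function(f, X, Y) {
  mean(apply(X, 1, f)) - log(mean(exp(apply(Y, 1, f))))
}
# Example with an arbitrary, non-optimized test function:
f <- function(x) sum(x^2) / 4
X <- matrix(rnorm(200), ncol = 2)            # samples from P
Y <- matrix(rnorm(300, sd = 1.5), ncol = 2)  # samples from Q
dv_objective(f, X, Y)   # a lower bound on D_KL(P||Q) for this particular f

Optimizing this bound over a flexible class of functions f (here, a neural network) tightens it towards the true KL divergence.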
Usage
kld_est_neural(
  X,
  Y,
  d_hidden = 1024,
  learning_rate = 1e-04,
  epochs = 5000,
  device = c("cpu", "cuda", "mps"),
  verbose = FALSE
)
Arguments
- X, Y
  n-by-d and m-by-d numeric matrices, representing n samples from the true distribution \(P\) and m samples from the approximate distribution \(Q\), both in d dimensions. Vector input is treated as a column matrix.
- d_hidden
  Number of nodes in the hidden layer (default: 1024)
- learning_rate
  Learning rate during gradient descent (default: 1e-4)
- epochs
  Number of training epochs (default: 5000)
- device
  Calculation device, either "cpu" (default), "cuda" or "mps".
- verbose
  Generate a progress report to the console during training of the neural network (default: FALSE)
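For example, passing numeric vectors is equivalent to passing one-column matrices (a small illustration; requires the torch package, as noted above):

# 1D samples: vectors are treated as n-by-1 and m-by-1 matrices
x <- rnorm(100)            # samples from P = N(0,1)
y <- rnorm(100, sd = 2)    # samples from Q = N(0,4)
if (requireNamespace("torch", quietly = TRUE)) {
  # equivalent to kld_est_neural(cbind(x), cbind(y)) up to training randomness
  kld_est_neural(x, y)
}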
Details
Disclaimer: this is a simple test implementation which is not optimized by any means. In particular:
- it uses a fully connected network with (only) a single hidden layer,
- it uses standard gradient descent on the full dataset rather than more advanced, batch-based optimizers.
Also, the syntax is likely to change in the future.
Estimation is done as described for mutual information in Belghazi et al. (see reference below), except that standard gradient descent is used on the full samples X and Y instead of on batches. Indeed, when X and Y have different lengths, batch sampling is not straightforward.
Reference: Belghazi et al., Mutual Information Neural Estimation, PMLR 80:531-540, 2018.
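A minimal sketch of this training scheme using torch, assuming a single hidden layer and full-batch gradient descent (illustrative only; dv_train_sketch and its defaults are hypothetical, not the package's actual implementation):

library(torch)
# Hypothetical helper maximizing the Donsker-Varadhan bound with full-batch
# gradient descent on a single-hidden-layer network (illustrative defaults).
dv_train_sketch <- function(X, Y, d_hidden = 32, lr = 1e-4, epochs = 200) {
  Xt <- torch_tensor(as.matrix(X), dtype = torch_float())
  Yt <- torch_tensor(as.matrix(Y), dtype = torch_float())
  net <- nn_sequential(
    nn_linear(ncol(as.matrix(X)), d_hidden),
    nn_relu(),
    nn_linear(d_hidden, 1)
  )
  opt <- optim_sgd(net$parameters, lr = lr)
  for (i in seq_len(epochs)) {
    opt$zero_grad()
    # negative DV bound: minimizing it maximizes E_P[f] - log E_Q[exp(f)]
    loss <- -(torch_mean(net(Xt)) - torch_log(torch_mean(torch_exp(net(Yt)))))
    loss$backward()
    opt$step()
  }
  as.numeric(-loss)   # final value of the bound, i.e. the KL estimate
}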
Examples
# 2D example
# analytical solution
kld_gaussian(mu1 = rep(0,2), sigma1 = diag(2),
             mu2 = rep(0,2), sigma2 = matrix(c(1,1,1,2), nrow = 2))
#> [1] 0.5
# sample generation
set.seed(0)
nxy <- 1000
X1 <- rnorm(nxy)
X2 <- rnorm(nxy)
Y1 <- rnorm(nxy)
Y2 <- Y1 + rnorm(nxy)
X <- cbind(X1,X2)
Y <- cbind(Y1,Y2)
# Estimation
kld_est_nn(X, Y)
#> [1] 0.2610792
if (FALSE) { # \dontrun{
# requires the torch package and takes ~1 min
kld_est_neural(X, Y)
} # }