API
Interface
ChangesOfVariables.with_logabsdet_jacobian (Function)
with_logabsdet_jacobian(f, x)
Computes both the transformed value of x under the transformation f and the logarithm of the volume element, i.e. the log-absolute-determinant of the Jacobian (LADJ).
For (y, ladj) = with_logabsdet_jacobian(f, x), the following must hold true:
- y == f(x)
- ladj is the log(abs(det(jacobian(f, x))))
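As a multivariate illustration of these two conditions, here is a minimal sketch assuming ChangesOfVariables is installed. For a linear map y = A * x the Jacobian is A itself, so ladj is log(abs(det(A))). The callable type MatMul is hypothetical and exists only for this sketch; LinearAlgebra.logabsdet returns a (value, sign) tuple, so only its first element is kept.

```julia
using ChangesOfVariables, LinearAlgebra

# Hypothetical callable type (illustration only): multiplies its input by a fixed square matrix.
struct MatMul{M<:AbstractMatrix}
    A::M
end
(f::MatMul)(x::AbstractVector) = f.A * x

# For y = A * x the Jacobian is A itself, so ladj = log(abs(det(A))).
# LinearAlgebra.logabsdet returns (log(abs(det(A))), sign(det(A))); keep the first entry.
function ChangesOfVariables.with_logabsdet_jacobian(f::MatMul, x::AbstractVector)
    y = f(x)
    ladj = first(logabsdet(f.A))
    return (y, ladj)
end

A = [2.0 1.0; 0.0 3.0]   # det(A) == 6
x = [1.0, -1.0]
y, ladj = with_logabsdet_jacobian(MatMul(A), x)
# y == [1.0, -3.0] and ladj ≈ log(6)
```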
with_logabsdet_jacobian comes with support for broadcasted/mapped functions (via Base.Broadcast.BroadcastFunction or Base.Fix1) and ComposedFunction.
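To illustrate the composition and broadcasting support, here is a minimal sketch assuming ChangesOfVariables is installed, that it provides an LADJ implementation for exp, and Julia >= 1.6 (for BroadcastFunction). The function double is hypothetical, with its LADJ written by hand. The composed LADJ should be the sum of the parts (the log-determinant of a product of Jacobians is the sum of their log-determinants), and a broadcasted scalar function has a diagonal Jacobian, so its LADJ is the sum of the element-wise LADJs.

```julia
using ChangesOfVariables

# Hypothetical scalar function with a hand-written LADJ (illustration only):
# d/dx (2x) = 2, so ladj = log(2) for every x.
double(x) = 2x
ChangesOfVariables.with_logabsdet_jacobian(::typeof(double), x::Real) =
    (double(x), log(oftype(float(x), 2)))

# ComposedFunction: the LADJs of the parts are added up (chain rule).
# d/dx exp(2x) = 2exp(2x), so log(abs(...)) = log(2) + 2x.
x = 0.5
z, ladj_z = with_logabsdet_jacobian(exp ∘ double, x)

# BroadcastFunction (Julia >= 1.6): the Jacobian is diagonal, so the LADJ
# is the sum of the element-wise LADJs, here 3 * log(2).
X = [0.5, 1.0, 1.5]
Y, ladj_Y = with_logabsdet_jacobian(Base.Broadcast.BroadcastFunction(double), X)
```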
If no volume element is defined/applicable, with_logabsdet_jacobian(f::F, x::T) returns NoLogAbsDetJacobian{F,T}().
Examples
using ChangesOfVariables
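foo(x) = inv(exp(-x) + 1)
# The derivative of foo is foo(x) * (1 - foo(x)) = exp(-x) * foo(x)^2,
# so log(abs(derivative)) = -x + 2 * log(foo(x)):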
function ChangesOfVariables.with_logabsdet_jacobian(::typeof(foo), x)
y = foo(x)
ladj = -x + 2 * log(y)
(y, ladj)
end
x = 4.2
y, ladj_y = with_logabsdet_jacobian(foo, x)
using LinearAlgebra, ForwardDiff
y == foo(x) && ladj_y ≈ log(abs(ForwardDiff.derivative(foo, x)))
# output
true
X = rand(10)
broadcasted_foo = if VERSION >= v"1.6"
Base.Broadcast.BroadcastFunction(foo)
else
Base.Fix1(broadcast, foo)
end
Y, ladj_Y = with_logabsdet_jacobian(broadcasted_foo, X)
Y == broadcasted_foo(X) && ladj_Y ≈ logabsdet(ForwardDiff.jacobian(broadcasted_foo, X))[1]
# output
true
VERSION < v"1.6" || begin # Support for ∘ requires Julia >= v1.6
z, ladj_z = with_logabsdet_jacobian(log ∘ foo, x)
z == log(foo(x)) && ladj_z == ladj_y + with_logabsdet_jacobian(log, y)[2]
end
# output
true
Implementations of with_logabsdet_jacobian can be tested (as a Test.@testset) using ChangesOfVariables.test_with_logabsdet_jacobian.
ChangesOfVariables.NoLogAbsDetJacobian (Type)
struct NoLogAbsDetJacobian{F,T}
An instance NoLogAbsDetJacobian{F,T}() signifies that with_logabsdet_jacobian(::F, ::T) is not defined.
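A minimal sketch of this fallback, assuming ChangesOfVariables is installed (and that it defines an LADJ for exp, used in the check below). The names cube and has_ladj are hypothetical and used only here. Because the fallback returns a NoLogAbsDetJacobian instance instead of a (y, ladj) tuple, generic code can test for it with isa before destructuring.

```julia
using ChangesOfVariables
using ChangesOfVariables: NoLogAbsDetJacobian

# Hypothetical function for which no with_logabsdet_jacobian method is defined.
cube(x) = x^3

result = with_logabsdet_jacobian(cube, 2.0)
# result is NoLogAbsDetJacobian{typeof(cube),Float64}(), not a (y, ladj) tuple

# Hypothetical helper (not part of ChangesOfVariables) for generic code that
# wants to check for an implementation before destructuring the result.
has_ladj(f, x) = !(with_logabsdet_jacobian(f, x) isa NoLogAbsDetJacobian)
```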
ChangesOfVariables.setladj (Function)
setladj(f, ladjf)::Function
Return a function g that behaves like f in general and for which with_logabsdet_jacobian(g, x) returns (f(x), ladjf(x)).
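A minimal sketch, assuming ChangesOfVariables is installed (setladj is imported explicitly in case it is not exported). Here cube is a hypothetical function with no with_logabsdet_jacobian method of its own, and the LADJ function handed to setladj is derived by hand from d/dx x^3 = 3x^2.

```julia
using ChangesOfVariables
using ChangesOfVariables: setladj

# Hypothetical function without its own with_logabsdet_jacobian method.
cube(x) = x^3

# d/dx x^3 = 3x^2, so log(abs(3x^2)) = log(3) + 2 * log(abs(x)).
cube_with_ladj = setladj(cube, x -> log(3) + 2 * log(abs(x)))

y_plain = cube_with_ladj(2.0)                           # behaves like cube(2.0)
y, ladj = with_logabsdet_jacobian(cube_with_ladj, 2.0)  # (cube(2.0), ladjf(2.0))
```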
This is useful if with_logabsdet_jacobian is not defined for f, or if f needs to be assigned an LADJ calculation that is only valid within a given context (e.g. only for a limited argument type or range that is guaranteed by the use case but not in general), or that is optimized for a custom use case.
For example, CUDA.CuArray has no with_logabsdet_jacobian defined, but may be used to switch the computing device for part of a heterogeneous computing function chain. Likewise, one may want to switch numerical precision for part of a calculation.
The function (wrapper) returned by setladj supports InverseFunctions.inverse if f does so.
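Example:
using ChangesOfVariables: with_logabsdet_jacobian, setladj
using InverseFunctions: inverse, setinverse
VERSION < v"1.6" || begin # Support for ∘ requires Julia >= v1.6
# Increase precision before calculating exp: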
foo = exp ∘ setladj(setinverse(Float64, Float32), _ -> 0)
# A log-value from some low-precision (e.g. GPU) computation:
log_x = Float32(100)
# exp(log_x) would return Inf32 without converting to Float64:
y, ladj = with_logabsdet_jacobian(foo, log_x)
r_log_x, ladj_inv = with_logabsdet_jacobian(inverse(foo), y)
ladj ≈ 100 ≈ -ladj_inv && r_log_x ≈ log_x
end
# output
true
Test utility
ChangesOfVariables.test_with_logabsdet_jacobian (Function)
ChangesOfVariables.test_with_logabsdet_jacobian(f, x, getjacobian; compare = isapprox, kwargs...)
Test if with_logabsdet_jacobian(f, x) is implemented correctly.
Checks if the result of with_logabsdet_jacobian(f, x) is approximately equal to (f(x), logabsdet(getjacobian(f, x))).
The test uses getjacobian(f, x) to calculate a reference Jacobian for f at x. Passing ForwardDiff.jacobian, Zygote.jacobian or similar as the getjacobian function will do fine in most cases. If input and output of f are real scalar values, use ForwardDiff.derivative.
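For a real scalar function, the reference "Jacobian" only needs to be the derivative. The sketch below assumes ChangesOfVariables is installed; double and central_diff are hypothetical, and central_diff (a central finite difference) merely stands in for ForwardDiff.derivative so that no AD package is needed. The rtol keyword is forwarded to the default compare (isapprox) to tolerate finite-difference error.

```julia
using ChangesOfVariables
using Test  # needed on Julia >= 1.9 to make test_with_logabsdet_jacobian available

# Hypothetical scalar function with a hand-written LADJ (illustration only).
double(x) = 2x
ChangesOfVariables.with_logabsdet_jacobian(::typeof(double), x::Real) =
    (double(x), log(oftype(float(x), 2)))

# Hypothetical stand-in for ForwardDiff.derivative: a central finite difference.
# It is called as getjacobian(f, x) and returns a real scalar.
central_diff(f, x; h = 1e-6) = (f(x + h) - f(x - h)) / (2h)

# Runs a Test.@testset; the extra keyword argument rtol is forwarded to
# compare (isapprox by default) to allow for finite-difference error.
ChangesOfVariables.test_with_logabsdet_jacobian(double, 0.7, central_diff; rtol = 1e-6)
```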
Note that the result of getjacobian(f, x) must be a real-valued matrix or a real scalar, so you may need to use a custom getjacobian function that transforms the shape of x and f(x) internally, in conjunction with automatic differentiation.
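When x or f(x) is not a vector, one possible custom getjacobian flattens both with vec before differentiating. This is a sketch under assumptions: ChangesOfVariables is installed, Julia >= 1.6 (for BroadcastFunction), and double and flat_fd_jacobian are hypothetical helpers, with central finite differences standing in for automatic differentiation.

```julia
using ChangesOfVariables
using Test  # needed on Julia >= 1.9 to make test_with_logabsdet_jacobian available

# Hypothetical scalar function with a hand-written LADJ (illustration only).
double(x) = 2x
ChangesOfVariables.with_logabsdet_jacobian(::typeof(double), x::Real) =
    (double(x), log(oftype(float(x), 2)))

# Maps a 2x2 matrix to a 2x2 matrix, so neither x nor f(x) is a vector.
bc_double = Base.Broadcast.BroadcastFunction(double)

# Hypothetical custom getjacobian: flattens input and output (column-major, as vec does)
# so the reference Jacobian is a real-valued length(f(x)) x length(x) matrix.
# Central finite differences stand in for automatic differentiation here.
function flat_fd_jacobian(f, x::AbstractArray{<:Real}; h = 1e-6)
    y = vec(f(x))
    J = zeros(length(y), length(x))
    for j in 1:length(x)
        xp = copy(float(x)); xp[j] += h   # perturb the j-th entry (linear index)
        xm = copy(float(x)); xm[j] -= h
        J[:, j] = (vec(f(xp)) .- vec(f(xm))) ./ (2h)
    end
    return J
end

X = [0.5 1.0; 1.5 2.0]
ChangesOfVariables.test_with_logabsdet_jacobian(bc_double, X, flat_fd_jacobian; rtol = 1e-6)
```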
kwargs... are forwarded to compare.
Note: On Julia >= 1.9, you have to load the Test standard library to be able to use this function.
Additional functionality
ChangesOfVariables.FunctionWithLADJ (Type)
struct FunctionWithLADJ{F,LADJF} <: Function
A function with a separate function to compute its logabsdet(J).
Do not construct directly; use setladj(f, ladjf) instead.
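As a brief sketch, assuming ChangesOfVariables is installed and provides an LADJ for exp (cube is a hypothetical function), the object returned by setladj can be checked against this type. Since it subtypes Function, it composes with ∘ like any other function, and the LADJs of the parts add up.

```julia
using ChangesOfVariables
using ChangesOfVariables: setladj, FunctionWithLADJ

# Hypothetical function; its LADJ log(3) + 2 * log(abs(x)) follows from d/dx x^3 = 3x^2.
cube(x) = x^3
g = setladj(cube, x -> log(3) + 2 * log(abs(x)))

# At x = 1.0: g contributes log(3) + 2 * log(1) = log(3), and exp (applied to g(1.0) = 1.0)
# contributes log(exp(1.0)) = 1.0, so the total LADJ is log(3) + 1.0.
z, ladj_z = with_logabsdet_jacobian(exp ∘ g, 1.0)
```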