API Reference
Main API
MuSink.Problem — Type
MuSink problem type that defines the fixed parts of a UMOT problem.
Contains the target and reference measures as well as the cost topology and marginal penalty.
MuSink.Problem — Method
Problem(targets [; cost, penalty, references, reference_mass])
Default constructor of a MuSink problem.
The arguments targets and references are expected to be of type Dict{Node, Array{Float64, 3}}, assigning target and reference measures to each node of the tree. If a reference_mass is specified, the references will be scaled such that their product measure has mass reference_mass.
MuSink.Workspace — Type
Core structure for solving UMOT problems. Besides the static properties provided by a Problem (tree topology, target and reference measures, cost and penalty functions), a Workspace also stores various parameters and auxiliary variables of the optimization problem.
Workspaces are powered by fast implementations of the inner loop of Sinkhorn iterations. These iterations can be performed either in the exponential domain (fast, but breaks down for small regularization parameters eps) or the logarithmic domain (slower, but remains stable for small eps). The domain can be specified via the keyword argument logdomain. The default is logdomain = true.
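The breakdown of the exponential domain for small eps can be illustrated without MuSink. The following is a minimal, self-contained sketch of a single softmin-type half step on a toy 1D problem; all names (logsumexp, xs, ys, epsilon, and so on) are chosen for illustration and are not part of the MuSink API, and the actual kernels in MuSink may be organized differently.

```julia
# Log-sum-exp with the usual max-shift, which avoids underflow of all terms at once.
function logsumexp(x)
    m = maximum(x)
    return m + log(sum(exp.(x .- m)))
end

# Toy 1D setup (illustrative only): two well separated point clouds, squared distance cost.
xs = collect(1.0:5.0)
ys = xs .+ 30.0
cost = [(x - y)^2 for x in xs, y in ys]
nu = fill(1 / length(ys), length(ys))   # uniform measure on ys
epsilon = 0.5                           # small regularization parameter

# Exponential domain: every kernel entry exp(-cost / epsilon) underflows to zero,
# so the update degenerates to -epsilon * log(0) = Inf.
kernel = exp.(-cost ./ epsilon)
expdomain_potential = -epsilon .* log.(kernel * nu)

# Logarithmic domain: the same update written with logsumexp stays finite.
logdomain_potential = [-epsilon * logsumexp(-cost[i, :] ./ epsilon .+ log.(nu)) for i in eachindex(xs)]

println(expdomain_potential)
println(logdomain_potential)
```

The log-domain version evaluates more exponentials and logarithms per update, which matches the speed versus stability trade-off described above.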
Sinkhorn steps and update rules
Sinkhorn steps that modify the Workspace (or, more precisely, the potentials and auxiliary α-arrays stored in the workspace) can conveniently be performed via the step! or steps! function. Based on the array type of the workspace, the properties of the cost function (separable or not), and the operating domain (logdomain or expdomain), specialized kernels are used.
Since the order of the updates can influence the stability and convergence of the algorithm, several different update rules (accessible via the keyword argument stepmode) are implemented:
stepmode = :legacy: a backwards pass through the tree (i.e., updating the value of α-arrays from parents to leaves) is followed by a forward pass (i.e., updating the value of α-arrays from leaves to parents) that also updates the potentials. This update rule is proposed by the authors of the original UMOT manuscript and seems to work fine in most configurations. However, in certain situations, the update rule blocks the propagation of information through the tree, such that potential updates within a step remain unaware of one another. A possible consequence is mass oscillation, which prevents convergence. This behavior is explicit in the barycenter problem (e.g., a star-shaped tree topology where the center node has fixed potentials via rho = 0) if the root node is set to the center. If the root node is set to a leaf of the tree, stepmode = :legacy should always work.
stepmode = :stable: a backwards pass through the tree is followed by a mixed forward / backward pass that updates the potentials. This update rule fixes the issues of the :legacy method but also requires more calculations per step.
stepmode = :alternate: only a mixed forward / backward pass that updates the potentials is conducted. This update rule uses as many computations per step as the :legacy method (i.e., it is faster than the :stable method) but should be protected against harmful information barriers. Still, this method is more aggressive and might thus be more susceptible to instabilities than :stable.
stepmode = :symmetric: a backwards pass is followed by a forward pass. Both of these passes update the values of α only. Afterwards, all potentials are updated simultaneously. This update rule is severely broken for more than two measures and is subject to change or deprecation.
Parameters
The UMOT problem relies on several parameters to be chosen by the user, like the parameters eps (strength of the entropic regularization), reach (the interaction radius), or rho (strength of the marginal penalty). The latter can be set per node. It is additionally possible to specify the weight of individual edges of the tree, which corresponds to a scaling of the cost function along that edge.
MuSink.Workspace — Method
Workspace(problem::Problem; <keyword arguments>)
Create a workspace in which problem can be solved by calling step! or converge!.
All optional arguments arg can also be set via set_[arg]!(ws, ...) after construction.
Arguments
eps = 1. The Sinkhorn scaling parameter.
rho = Inf. The default value of the marginal penalty parameter rho.
rhos. Dictionary that maps nodes to rho values.
weights. Dictionary that maps edges to values that weight the cost.
reach = Inf. Reach parameter (radius of maximally allowed transport).
logdomain = true. Whether to work in the log or exp domain.
stepmode = :stable. Strategy that determines the order of potential updates. See the documentation of Workspace for details.
atype = :array64. Array type used in the backend.
MuSink.Chain — Function
Chain(targets; kwargs...)
Directly create a workspace that puts targets in a linear cost tree.
Valid keyword arguments are the ones for Problem (except for root) and Workspace.
MuSink.Barycenter — Function
Barycenter(targets; kwargs...)
Directly create a workspace that puts targets in a star-shaped cost tree around a barycenter node with rho = 0.
Valid keyword arguments are the ones for Problem (except for root) and Workspace.
MuSink.step! — Function
step!(ws::Workspace, steps = 1; max_time = Inf, stepmode = ws.stepmode)
Perform steps Sinkhorn update steps with a given stepmode.
The argument max_time denotes a time limit in seconds, after which no more steps are performed.
MuSink.converge! — Function
converge!(workspace; <keyword arguments>)
Auxiliary method that implements epsilon scaling to approximate the unregularized problem.
Arguments
start_eps = nothing. Initial value of epsilon. By default, the current epsilon of workspace is used.
target_eps = nothing. Target value of epsilon. By default, determined indirectly via target_blur.
target_blur = nothing. Target value of the average blur of the transport plan. Either this value or target_eps should be passed explicitly.
scaling = 0.9. Factor by which epsilon is reduced per scaling step.
tolerance = 1e-2. Convergence tolerance to be achieved before epsilon is scaled down.
target_tolerance = tolerance. Convergence tolerance for the final scaling step.
max_time = 600. Maximal time in seconds that the function is allowed to run.
max_steps = 10000. Maximal number of steps that the function is allowed to conduct.
callback. Procedure that is called after each Sinkhorn step.
callback_scaling. Procedure that is called after each (partial) convergence.
verbose = false. If true, the function documents its progress.
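The sequence of epsilon values visited by such a scaling scheme can be sketched independently of MuSink. The helper eps_schedule below is hypothetical and not part of the package; in particular, clamping the last value to the target is an assumption about how the final scaling step is handled, and the convergence checks against tolerance between scaling steps are omitted.

```julia
# Hypothetical helper (not a MuSink function): list the epsilon values of a
# geometric scaling schedule from start_eps down to target_eps.
function eps_schedule(start_eps, target_eps; scaling = 0.9)
    @assert 0 < scaling < 1
    @assert 0 < target_eps <= start_eps
    schedule = [float(start_eps)]
    while schedule[end] > target_eps
        # Reduce epsilon by the scaling factor, but never go below the target
        # (assumption made for this sketch).
        push!(schedule, max(schedule[end] * scaling, target_eps))
    end
    return schedule
end

schedule = eps_schedule(1.0, 0.5)
println(schedule)
```

In an actual solver loop, each value in the schedule would be held until the convergence tolerance is reached before moving on to the next one.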
MuSink.marginal — Function
marginal(ws::Workspace, node; keep_batchdim = false)
The marginal measure at node with respect to the counting measure.
If keep_batchdim = true, the singular batch dimension in case of batchsize 1 is not dropped.
MuSink.target — Function
target(p::Problem, a)
Returns the target measure of p at node a.
target(ws::Workspace, node; keep_batchdim = false)
The target measure at node with respect to the counting measure.
If keep_batchdim = true, the singular batch dimension in case of batchsize 1 is not dropped.
Couplings and transport
MuSink.Coupling — Type
Structure that facilitates calculating (parts of) the transport plan between two nodes.
Since the full transport plan is often impractically large, this type provides a lazy interface that operates on the dual potentials.
Some of the operations on Couplings can only be implemented performantly when nodes are adjacent. These operations fail for couplings between non-neighboring nodes.
MuSink.transport — Function
transport(plan::Coupling, i, j; conditional = false)
transport(plan::Coupling, (i, j); conditional = false)
transport(plan::Coupling, is, js; conditional = false)
transport(plan::Coupling, (is, js); conditional = false)
transport(ws::Workspace, a, b, i, j; conditional = false)
transport(ws::Workspace, a, b, (i, j); conditional = false)
transport(ws::Workspace, a, b, is, js; conditional = false)
transport(ws::Workspace, a, b, (is, js); conditional = false)
Returns the evaluation of the coupling plan at pixel (i, j) of node a. If iterables is and js are provided, a vector of transport arrays is returned.
If conditional = true, the transport arrays sum to one.
If a workspace ws as well as two nodes a and b are provided, the corresponding coupling is calculated implicitly.
MuSink.transport_window — Function
transport_window(plan::Coupling, i, j; conditional = false)
transport_window(plan::Coupling, (i, j); conditional = false)
transport_window(ws::Workspace, a, b, i, j; conditional = false)
transport_window(ws::Workspace, a, b, (i, j); conditional = false)
Like transport but returns a window of radius reach around the pixel position (i, j).
Reductions
MuSink.Reductions.Reduction — Type
Reduction over a coupling.
This type performantly implements sums of the form
sum_{b} pi_{ab} * f(a-b)
where pi is a coupling.
Since evaluation of f takes place in the logdomain, it must return strictly positive values.
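What such a sum computes can be made concrete with a dense toy example. The sketch below is a hypothetical 1D stand-in (dense_reduce and the toy plan are illustrative names, not MuSink functions); the actual Reduction type works lazily on the dual potentials of 2D measures and does not materialize the coupling. The conditional keyword here divides by the marginal of the first node.

```julia
# Hypothetical dense 1D stand-in: out[a] = sum_b plan[a, b] * f(a - b).
function dense_reduce(f, plan::AbstractMatrix; conditional = false)
    out = [sum(plan[a, b] * f(a - b) for b in axes(plan, 2)) for a in axes(plan, 1)]
    if conditional
        # Divide pointwise by the marginal measure of the first node.
        out = out ./ vec(sum(plan; dims = 2))
    end
    return out
end

# Toy coupling: all mass at position a moves one step to the right (b = a + 1).
plan = [b == a + 1 ? 0.25 : 0.0 for a in 1:4, b in 1:5]

# A strictly positive f, in line with the positivity requirement stated above.
f(d) = 1 + d^2

unconditional = dense_reduce(f, plan)
conditional = dense_reduce(f, plan; conditional = true)
println(unconditional)
println(conditional)
```

Since every unit of mass is shifted by exactly one position, the conditional result equals f(-1) at each source position.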
MuSink.Reductions.reduce — Function
reduce(r::Reduction, plan; conditional)
reduce(r::Reduction, workspace, a, b; conditional)
reduction(plan::Coupling; conditional)
reduction(ws::Workspace, a, b; conditional)
Apply the reduction r to a coupling plan or to a workspace ws between nodes a and b. If conditional = true, the result is divided pointwise by the marginal measure of node a.
Predefined Reductions
MuSink.Reductions.ishift — Function
ishift(coupling)
ishift(workspace, a, b)
Pointwise mean shift in the first component of the transport plan from a to b.
MuSink.Reductions.jshift — Function
jshift(coupling)
jshift(workspace, a, b)
Pointwise mean shift in the second component of the transport plan from a to b.
MuSink.Reductions.ishiftsq — Function
ishiftsq(coupling)
ishiftsq(workspace, a, b)
Pointwise mean squared shift in the first component of the transport plan from a to b.
MuSink.Reductions.jshiftsq — Function
jshiftsq(coupling)
jshiftsq(workspace, a, b)
Pointwise mean squared shift in the second component of the transport plan from a to b.
MuSink.Reductions.ivar — Function
ivar(coupling)
ivar(workspace, a, b)
Pointwise variance of the first component of the transport plan from a to b.
MuSink.Reductions.jvar — Function
jvar(coupling)
jvar(workspace, a, b)
Pointwise variance of the second component of the transport plan from a to b.
MuSink.Reductions.var — Function
var(coupling)
var(workspace, a, b)
Pointwise variance (both components) of the transport plan from a to b.
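The relation between mean shift, mean squared shift, and variance can be checked on a dense toy coupling. The sketch below is self-contained and does not call MuSink; the sign convention of the shift (taken as b - a here) and the use of the identity variance = E[shift²] - E[shift]² are assumptions for illustration, not statements about how the package computes these quantities internally.

```julia
# Toy 1D coupling: at each source position, half of the mass stays in place
# and half moves two steps to the right.
plan = [(b == a || b == a + 2) ? 0.5 : 0.0 for a in 1:3, b in 1:5]

marginal_a = vec(sum(plan; dims = 2))

# Conditional expectation of g(shift) given the source position a,
# with the shift taken as b - a (assumed convention).
expect(g) = [sum(plan[a, b] * g(b - a) for b in axes(plan, 2)) / marginal_a[a]
             for a in axes(plan, 1)]

mean_shift = expect(d -> d)         # pointwise mean shift
mean_sq_shift = expect(d -> d^2)    # pointwise mean squared shift
variance = mean_sq_shift .- mean_shift .^ 2
println(mean_shift)
println(variance)
```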
MuSink.Reductions.std — Function
std(coupling)
std(workspace, a, b)
Pointwise standard deviation (both components) of the transport plan from a to b.
MuSink.Reductions.imap — Function
imap(coupling)
imap(workspace, a, b)
Pointwise mean position of the first component of the transport plan from a to b.
MuSink.Reductions.jmap — Function
jmap(coupling)
jmap(workspace, a, b)
Pointwise mean position of the second component of the transport plan from a to b.
MuSink.Reductions.coloc — Function
coloc(coupling; threshold, conditional = false)
coloc(workspace, a, b; threshold, conditional = false)
Colocalization with cost threshold threshold.
Trees
MuSink.Tree.Node — Type
Node type. Carries information about children and its parent node. Has an associated index that can be used to identify it within a tree.
Note that no dedicated Tree type exists. Instead, a root node, which has no parent, encodes the tree-associated structure.
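The idea of encoding a tree purely through linked nodes can be sketched as follows. SketchNode, add_child!, sketch_root, and sketch_descendants are hypothetical names for this illustration; the actual MuSink.Tree.Node type and its constructors may be laid out differently.

```julia
# Hypothetical minimal node type (not MuSink.Tree.Node): an index, an optional
# parent, and a list of children.
mutable struct SketchNode
    index::Int
    parent::Union{SketchNode, Nothing}
    children::Vector{SketchNode}
end

# A node created without parent acts as the root of a new tree.
SketchNode(index) = SketchNode(index, nothing, SketchNode[])

function add_child!(parent::SketchNode, index)
    child = SketchNode(index, parent, SketchNode[])
    push!(parent.children, child)
    return child
end

# Walk upwards until a node without parent is reached.
sketch_root(a::SketchNode) = isnothing(a.parent) ? a : sketch_root(a.parent)

# Collect a node's descendants depth-first, optionally including the node itself.
function sketch_descendants(a::SketchNode, include_node = true)
    result = include_node ? [a] : SketchNode[]
    for child in a.children
        append!(result, sketch_descendants(child, true))
    end
    return result
end

# Star-shaped tree: one center node with three leaves.
center = SketchNode(1)
leaves = [add_child!(center, i) for i in 2:4]
println(sketch_root(leaves[3]).index)
println([n.index for n in sketch_descendants(center)])
```

Because every node can reach the root by following parent links, holding on to any single node is enough to recover the whole tree.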
MuSink.Tree.root — Function
root(a::Node)
Returns the root node of the tree that contains a.
MuSink.Tree.parent — Function
parent(a::Node)
Returns the parent node of a. Returns nothing if a is the root node.
MuSink.Tree.children — Function
children(a::Node)
Returns all child nodes of a.
MuSink.Tree.descendants — Function
descendants(a::Node, include_node = true)
Returns all descendants of a, including a if include_node = true is passed.
Querying workspaces
MuSink.get_eps — Function
get_eps(ws::Workspace)
Returns the current value of eps.
MuSink.get_rho — Function
get_rho(ws::Workspace)
get_rho(ws::Workspace, a)
Returns the current value of rho (default or at node a).
MuSink.get_reach — Function
get_reach(ws::Workspace)
Returns the current value of reach.
MuSink.get_weight — Function
get_weight(ws::Workspace, a, b)
Returns the cost weight between the nodes a and b.
MuSink.get_stepmode — Function
get_stepmode(ws::Workspace)
Returns the stepmode of ws.
MuSink.get_domain — Function
get_domain(ws::Workspace)
Returns true if ws is in the logdomain and false otherwise.
MuSink.potential — Function
potential(ws::Workspace, node; keep_batchdim = false)
The potential at node. Note that this function always returns the actual dual UMOT potential (i.e., the potential in the logdomain).
If keep_batchdim = true, the singular batch dimension in case of batchsize 1 is not dropped.
MuSink.nodes — Function
nodes(p::Problem)
Returns all nodes of the problem cost tree.
nodes(ws::Workspace)
Returns all nodes of the problem cost tree.
MuSink.edges — Function
edges(p::Problem)
Returns all (bidirectional) edges of the problem cost tree.
edges(ws::Workspace)
Returns all (bidirectional) edges of the problem cost tree.
Modifying workspaces
MuSink.set_eps! — Function
set_eps!(ws::Workspace, eps)
Set the value of epsilon to eps. The logdomain potentials are kept.
MuSink.set_rho! — Function
set_rho!(ws::Workspace, rho)
set_rho!(ws::Workspace, a, rho)
Set the default value of rho to rho. If a node (or node index) a is provided, only set the value for this specific node.
MuSink.set_reach! — Function
set_reach!(ws::Workspace, reach::Integer)
Set the reach of ws to reach.
MuSink.set_weight! — Function
set_weight!(w::Workspace, a, b, weight)
Set the cost weight of the edge between a and b to weight.
MuSink.set_stepmode! — Function
set_stepmode!(ws::Workspace, stepmode)
Set the step mode of ws to stepmode.
MuSink.set_domain! — Function
set_domain!(ws::Workspace, logdomain::Bool)
If logdomain = true, move ws into the logdomain. If logdomain = false, move ws into the expdomain.
Remote Workspaces
MuSink.Remote.RemoteWorkspace — Type
Structure that enables remote MuSink computations.
Remote workspaces can be constructed like regular ones (see Workspace) and can then be initialized in a separate task / thread / worker via Remote.init.
MuSink.Remote.init — Function
init(ws::RemoteWorkspace, signal = nothing)
Initialize the remote workspace ws.
A channel signal can be passed, which receives the value true if initialization is successful.