"""
Recombination — the ARG topology change at a single recombination breakpoint.
Mirrors Recombination.cpp / Recombination.hpp.
A Recombination stores the sets of branches deleted and inserted at a
genomic position `pos`. From the deleted/inserted sets it derives the
deleted/inserted nodes and the key named branches (source, target, merging,
recombined, transfer branches). The recombination height `start_time` is
stored here as well, but it is assigned by the sampling code rather than
derived from the branch sets. These are used by the BSP/TSP transfer steps
and by the MCMC threader.
"""
from __future__ import annotations
import sys
from typing import Optional, Set, TYPE_CHECKING
if TYPE_CHECKING:
from .branch import Branch
from .node import Node
# Branch is imported lazily inside the methods that construct it, which
# avoids a circular import at module load time.
class Recombination:
    """Topology change record for one recombination breakpoint.

    Mirrors the C++ ``Recombination`` class (Recombination.cpp / .hpp).

    Attributes
    ----------
    pos:
        Genomic position of the breakpoint (set externally, see ``set_pos``).
    deleted_branches:
        Branches that exist *before* (left of) ``pos``.
    inserted_branches:
        Branches that exist *after* (right of) ``pos``.

    Derived by ``_find_nodes``:

    inserted_node:
        New internal node created at the coalescence (an upper node that
        occurs only among the inserted branches).
    deleted_node:
        Old internal node removed by the topology change (an upper node
        that occurs only among the deleted branches).

    Derived by ``_find_target_branch`` / ``_find_recomb_info``:

    target_branch:
        The lineage the recombined lineage joins at coalescence.
    merging_branch:
        The branch that takes over after the removed node.
    recombined_branch:
        Branch from the source branch's lower node up to ``inserted_node``.
    lower_transfer_branch, upper_transfer_branch:
        For BSP interval transfer.
    source_sister_branch, source_parent_branch:
        Auxiliary branches adjacent to ``deleted_node``.

    Maintained elsewhere:

    source_branch:
        The lineage that recombines. This class only re-maps it in
        ``remove`` / ``add``; the initial value is expected from outside.
    start_time:
        Height at which recombination begins. It is *not* computed in this
        class; it is assigned by the sampling code.
    """

    # Sentinel position, standing in for INT_MAX of the C++ original. Like
    # pos == 0 it switches off forward tracing and info derivation.
    # NOTE(review): `pos` is annotated as float, and float(sys.maxsize) does
    # not compare equal to sys.maxsize -- confirm callers store the int.
    _POS_MAX = sys.maxsize

    def __init__(
        self,
        deleted_branches: Optional[Set] = None,
        inserted_branches: Optional[Set] = None,
    ) -> None:
        """Create a record from optional deleted / inserted branch sets.

        The given collections are copied. When at least one is non-empty,
        branches present in both are cancelled out and the deleted /
        inserted nodes are identified straight away.
        """
        from .branch import Branch

        self.pos: float = 0.0
        self.deleted_branches: Set[Branch] = (
            set(deleted_branches) if deleted_branches else set()
        )
        self.inserted_branches: Set[Branch] = (
            set(inserted_branches) if inserted_branches else set()
        )

        # Derived by _find_nodes() / _find_recomb_info(); default-constructed
        # ("null") branches act as placeholders until then.
        self.deleted_node: Optional["Node"] = None
        self.inserted_node: Optional["Node"] = None
        self.source_branch: "Branch" = Branch()
        self.target_branch: "Branch" = Branch()
        self.merging_branch: "Branch" = Branch()
        self.recombined_branch: "Branch" = Branch()
        self.source_sister_branch: "Branch" = Branch()
        self.source_parent_branch: "Branch" = Branch()
        self.lower_transfer_branch: "Branch" = Branch()
        self.upper_transfer_branch: "Branch" = Branch()
        self.start_time: float = 0.0

        if self.deleted_branches or self.inserted_branches:
            self._simplify_branches()
            self._find_nodes()

    def set_pos(self, x: float) -> None:
        """Set the genomic position of the breakpoint."""
        self.pos = x

    # ------------------------------------------------------------------
    # Query
    # ------------------------------------------------------------------
    def affect(self, b: "Branch") -> bool:
        """True iff *b* is in deleted_branches."""
        return b in self.deleted_branches

    def create(self, b: "Branch") -> bool:
        """True iff *b* is in inserted_branches."""
        return b in self.inserted_branches

    # ------------------------------------------------------------------
    # Branch-tracing (used by ARG.remove / ARG.add)
    # ------------------------------------------------------------------
    def trace_forward(self, t: float, curr_branch: "Branch") -> "Branch":
        """Return the branch at time *t* that *curr_branch* maps to after pos.

        Mirrors Recombination::trace_forward.

        A null ``Branch()`` is returned when ``pos`` is 0 or the sentinel,
        or when the point lies on the source branch at or above
        ``start_time``. Branches this recombination does not delete map to
        themselves.
        """
        from .branch import Branch

        null_branch = Branch()
        if self.pos == 0 or self.pos == self._POS_MAX:
            return null_branch
        if not self.affect(curr_branch):
            return curr_branch

        if curr_branch == self.source_branch:
            if t >= self.start_time:
                return null_branch
            return self.recombined_branch
        if curr_branch == self.target_branch:
            # The target branch is split at inserted_node; pick the half
            # that contains time t.
            if t > self.inserted_node.time:
                return self.upper_transfer_branch
            return self.lower_transfer_branch
        return self.merging_branch

    def trace_backward(self, t: float, curr_branch: "Branch") -> "Branch":
        """Return the branch that maps to *curr_branch* before pos.

        Mirrors Recombination::trace_backward.

        A null ``Branch()`` is returned when there are no deleted branches,
        or when the point lies on the recombined branch at or above
        ``start_time``. Branches this recombination does not insert map to
        themselves.
        """
        from .branch import Branch

        null_branch = Branch()
        if not self.deleted_branches:
            return null_branch
        if not self.create(curr_branch):
            return curr_branch

        if curr_branch == self.recombined_branch:
            if t >= self.start_time:
                return null_branch
            return self.source_branch
        if curr_branch != self.merging_branch:
            return self.target_branch
        # On the merging branch: choose the deleted branch above or below
        # deleted_node depending on which side of the node t falls.
        if t > self.deleted_node.time:
            return self._search_lower_node(self.deleted_node)
        return self._search_upper_node(self.deleted_node)

    # ------------------------------------------------------------------
    # Mutation remove/add helpers (called by ARG.remove / ARG.add)
    # ------------------------------------------------------------------
    def remove(
        self,
        prev_removed_branch: "Branch",
        next_removed_branch: "Branch",
        prev_split_branch: "Branch",
        next_split_branch: "Branch",
        cut_node: Optional["Node"] = None,
    ) -> None:
        """Update this recombination when the surrounding topology is pruned.

        Mirrors Recombination::remove (both overloads); passing *cut_node*
        selects the overload that also tracks the branch ending at the cut.

        Parameters
        ----------
        prev_removed_branch, next_removed_branch:
            Branch of the removed lineage on the left / right of ``pos``.
            With *cut_node*, a null branch on one side routes to
            ``_break_front`` / ``_break_end``.
        prev_split_branch, next_split_branch:
            Branch that the removed lineage's upper node splits on the
            left / right of ``pos`` (presumably the branch re-formed once
            the lineage is gone -- confirm against the caller).
        cut_node:
            Node at which the removed lineage is cut, or ``None``.

        Does nothing if both branch sets are empty on entry.
        """
        from .branch import Branch

        null_branch = Branch()
        if not self.deleted_branches and not self.inserted_branches:
            return

        if cut_node is not None:
            # Overload with cut_node
            if (
                prev_removed_branch == next_removed_branch
                and prev_split_branch == next_split_branch
            ):
                return  # nothing changes across this breakpoint
            if prev_removed_branch == null_branch:
                self._break_front(next_removed_branch, next_split_branch, cut_node)
                return
            if next_removed_branch == null_branch:
                self._break_end(prev_removed_branch, prev_split_branch, cut_node)
                return
            self._record_removal(
                prev_removed_branch,
                next_removed_branch,
                prev_split_branch,
                next_split_branch,
            )
            self._add_deleted(Branch(prev_removed_branch.lower_node, cut_node))
            self._add_inserted(Branch(next_removed_branch.lower_node, cut_node))
            self._simplify_branches()
            # Fix source_branch if it was destroyed
            self._fix_source_after_remove(prev_split_branch, prev_removed_branch)
            self._find_nodes()
            self._find_target_branch()
            if self.deleted_branches:
                self._find_recomb_info()
        else:
            # Overload without cut_node
            self._record_removal(
                prev_removed_branch,
                next_removed_branch,
                prev_split_branch,
                next_split_branch,
            )
            self._simplify_branches()
            if not self.deleted_branches and not self.inserted_branches:
                return
            self._fix_source_after_remove(prev_split_branch, prev_removed_branch)
            self._find_nodes()
            self._find_target_branch()
            self._find_recomb_info()

    def add(
        self,
        prev_added_branch: "Branch",
        next_added_branch: "Branch",
        prev_joining_branch: "Branch",
        next_joining_branch: "Branch",
        cut_node: Optional["Node"] = None,
    ) -> None:
        """Update this recombination when a new lineage is threaded in.

        Mirrors Recombination::add.

        Parameters
        ----------
        prev_added_branch, next_added_branch:
            Branch of the new lineage on the left / right of ``pos``; a
            null branch means the lineage is absent on that side.
        prev_joining_branch, next_joining_branch:
            Existing branch that the new lineage joins (and thereby splits)
            on the left / right of ``pos``.
        cut_node:
            Optional node; when given, the branch from the added branch's
            lower node to it is tracked as well.

        Derived info is only recomputed when ``pos`` is non-zero and some
        deleted branches remain after simplification.
        """
        from .branch import Branch

        null_branch = Branch()
        if (
            prev_added_branch == next_added_branch
            and prev_joining_branch == next_joining_branch
        ):
            return  # identical on both sides: nothing changes here

        if next_added_branch != null_branch:
            self._add_inserted(next_added_branch)
            self._add_inserted(
                Branch(next_joining_branch.lower_node, next_added_branch.upper_node)
            )
            self._add_inserted(
                Branch(next_added_branch.upper_node, next_joining_branch.upper_node)
            )
            self._add_deleted(next_joining_branch)
            if cut_node is not None:
                self._add_deleted(Branch(next_added_branch.lower_node, cut_node))
        if prev_added_branch != null_branch:
            self._add_deleted(prev_added_branch)
            self._add_deleted(
                Branch(prev_joining_branch.lower_node, prev_added_branch.upper_node)
            )
            self._add_deleted(
                Branch(prev_added_branch.upper_node, prev_joining_branch.upper_node)
            )
            self._add_inserted(prev_joining_branch)
            if cut_node is not None:
                self._add_inserted(Branch(prev_added_branch.lower_node, cut_node))
        self._simplify_branches()

        if self.pos == 0:
            return
        if not self.deleted_branches:
            return
        self._find_nodes()

        # Update source_branch -- mirrors C++ Recombination::add:
        #   if (prev_joining == source_branch) {
        #       if (prev_added.upper == next_added.upper)  // pointer equality
        #           source_branch = Branch(prev_added.upper, source_branch.upper);
        #       else
        #           source_branch = Branch(source_branch.lower, prev_added.upper);
        #   } else { source_branch = search_lower_node(source_branch.lower); }
        if prev_joining_branch == self.source_branch:
            # `is` reproduces the pointer comparison of the C++ code.
            if prev_added_branch.upper_node is next_added_branch.upper_node:
                self.source_branch = Branch(
                    prev_added_branch.upper_node, self.source_branch.upper_node
                )
            else:
                self.source_branch = Branch(
                    self.source_branch.lower_node, prev_added_branch.upper_node
                )
        else:
            self.source_branch = self._search_lower_node(
                self.source_branch.lower_node
            )
        self._find_target_branch()
        self._find_recomb_info()

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _record_removal(
        self,
        prev_removed_branch: "Branch",
        next_removed_branch: "Branch",
        prev_split_branch: "Branch",
        next_split_branch: "Branch",
    ) -> None:
        """Register the branch-set changes common to both ``remove`` overloads."""
        from .branch import Branch

        self._add_deleted(prev_split_branch)
        self._add_deleted(next_removed_branch)
        self._add_deleted(
            Branch(next_split_branch.lower_node, next_removed_branch.upper_node)
        )
        self._add_deleted(
            Branch(next_removed_branch.upper_node, next_split_branch.upper_node)
        )
        self._add_inserted(next_split_branch)
        self._add_inserted(prev_removed_branch)
        self._add_inserted(
            Branch(prev_split_branch.lower_node, prev_removed_branch.upper_node)
        )
        self._add_inserted(
            Branch(prev_removed_branch.upper_node, prev_split_branch.upper_node)
        )

    def _fix_source_after_remove(
        self,
        prev_split_branch: "Branch",
        prev_removed_branch: "Branch",
    ) -> None:
        """Re-point source_branch at *prev_split_branch* if it was destroyed.

        The source branch counts as destroyed when it equals either half of
        *prev_split_branch* as split at the removed branch's upper node.
        """
        from .branch import Branch

        destroyed_lower = Branch(
            prev_split_branch.lower_node, prev_removed_branch.upper_node
        )
        destroyed_upper = Branch(
            prev_removed_branch.upper_node, prev_split_branch.upper_node
        )
        if (
            self.source_branch == destroyed_lower
            or self.source_branch == destroyed_upper
        ):
            self.source_branch = prev_split_branch

    def _break_front(
        self,
        next_removed: "Branch",
        next_split: "Branch",
        cut_node: "Node",
    ) -> None:
        """Handle ``remove`` when the removed lineage is null left of pos.

        Only the branch sets are updated; derived info is not recomputed.
        """
        from .branch import Branch

        self._add_deleted(next_removed)
        self._add_deleted(Branch(next_split.lower_node, next_removed.upper_node))
        self._add_deleted(Branch(next_removed.upper_node, next_split.upper_node))
        self._add_inserted(next_split)
        self._add_inserted(Branch(next_removed.lower_node, cut_node))
        self._simplify_branches()

    def _break_end(
        self,
        prev_removed: "Branch",
        prev_split: "Branch",
        cut_node: "Node",
    ) -> None:
        """Handle ``remove`` when the removed lineage is null right of pos.

        Only the branch sets are updated; derived info is not recomputed.
        """
        from .branch import Branch

        self._add_inserted(prev_removed)
        self._add_inserted(Branch(prev_split.lower_node, prev_removed.upper_node))
        self._add_inserted(Branch(prev_removed.upper_node, prev_split.upper_node))
        self._add_deleted(prev_split)
        self._add_deleted(Branch(prev_removed.lower_node, cut_node))
        self._simplify_branches()

    def _simplify_branches(self) -> None:
        """Remove branches that appear in both deleted and inserted sets."""
        common = self.deleted_branches & self.inserted_branches
        self.deleted_branches -= common
        self.inserted_branches -= common

    def _add_deleted(self, b: "Branch") -> None:
        """Add *b* to deleted_branches unless either endpoint is None."""
        if b.upper_node is not None and b.lower_node is not None:
            self.deleted_branches.add(b)

    def _add_inserted(self, b: "Branch") -> None:
        """Add *b* to inserted_branches unless either endpoint is None."""
        if b.upper_node is not None and b.lower_node is not None:
            self.inserted_branches.add(b)

    def _find_nodes(self) -> None:
        """Identify inserted_node and deleted_node from branch sets.

        A node qualifies when it is the upper node of a branch in one set
        but of none in the other. If no node qualifies, the previous value
        of the attribute is left untouched.
        """
        prev_upper = {b.upper_node for b in self.deleted_branches}
        next_upper = {b.upper_node for b in self.inserted_branches}
        for n in prev_upper:
            if n not in next_upper:
                self.deleted_node = n
        for n in next_upper:
            if n not in prev_upper:
                self.inserted_node = n

    def _find_target_branch(self) -> None:
        """Find the branch that the recombined lineage joins.

        A candidate is a deleted branch other than the source branch whose
        time span contains ``inserted_node.time``. The first pass requires
        both halves of the candidate (split at ``inserted_node``) to be
        inserted branches; the fallback pass accepts a single half. If
        nothing matches, or there is no inserted node, ``target_branch``
        becomes a null ``Branch()``.
        """
        from .branch import Branch

        if self.inserted_node is None:
            self.target_branch = Branch()
            return

        join_time = self.inserted_node.time
        for need_both_halves in (True, False):
            for b in self.deleted_branches:
                if b == self.source_branch:
                    continue
                if b.lower_node.time > join_time or b.upper_node.time < join_time:
                    continue
                lower = Branch(b.lower_node, self.inserted_node)
                upper = Branch(self.inserted_node, b.upper_node)
                if need_both_halves:
                    matched = (
                        lower in self.inserted_branches
                        and upper in self.inserted_branches
                    )
                else:
                    # Fallback: accept partial match
                    matched = (
                        lower in self.inserted_branches
                        or upper in self.inserted_branches
                    )
                if matched:
                    self.target_branch = b
                    return
        self.target_branch = Branch()

    def _find_recomb_info(self) -> None:
        """Compute merging_branch, recombined_branch, transfer branches, etc.

        Does nothing when ``pos`` is 0 or the sentinel. ``start_time`` is
        deliberately not touched here.
        """
        from .branch import Branch

        if self.pos == 0 or self.pos == self._POS_MAX:
            return

        dn = self.deleted_node
        merge_lower = None
        merge_upper = None
        for b in self.deleted_branches:
            if b == self.source_branch:
                continue
            # The merging branch bridges deleted_node: it runs from the
            # lower end of the branch below it to the upper end of the
            # branch above it. Where that neighbour is the target branch,
            # the new inserted_node is the endpoint instead.
            if b.upper_node is dn:
                merge_lower = (
                    self.inserted_node if b == self.target_branch else b.lower_node
                )
            elif b.lower_node is dn:
                merge_upper = (
                    self.inserted_node if b == self.target_branch else b.upper_node
                )
        self.merging_branch = Branch(merge_lower, merge_upper)
        self.recombined_branch = Branch(
            self.source_branch.lower_node, self.inserted_node
        )
        self.source_sister_branch = self._search_upper_node(dn)
        self.source_parent_branch = self._search_lower_node(dn)

        # Transfer branches: use the half of the target branch if it was
        # really inserted, otherwise fall back to the merging branch.
        candidate_lower = Branch(self.target_branch.lower_node, self.inserted_node)
        if candidate_lower in self.inserted_branches:
            self.lower_transfer_branch = candidate_lower
        else:
            self.lower_transfer_branch = self.merging_branch
        candidate_upper = Branch(self.inserted_node, self.target_branch.upper_node)
        if candidate_upper in self.inserted_branches:
            self.upper_transfer_branch = candidate_upper
        else:
            self.upper_transfer_branch = self.merging_branch
        # start_time is set by approx_sample_recombinations / sample_recombination,
        # NOT here -- mirroring C++ find_recomb_info which does not touch start_time.

    def _search_upper_node(self, n: "Node") -> "Branch":
        """Find the deleted branch (not source) with upper_node == n.

        Returns a null ``Branch()`` when there is none.
        """
        from .branch import Branch

        for b in self.deleted_branches:
            if b != self.source_branch and b.upper_node is n:
                return b
        return Branch()

    def _search_lower_node(self, n: "Node") -> "Branch":
        """Find the deleted branch with lower_node == n.

        Returns a null ``Branch()`` when there is none.
        """
        from .branch import Branch

        for b in self.deleted_branches:
            if b.lower_node is n:
                return b
        return Branch()

    def __repr__(self) -> str:
        return (
            f"Recombination(pos={self.pos}, "
            f"|del|={len(self.deleted_branches)}, "
            f"|ins|={len(self.inserted_branches)})"
        )