"""
ARG — the Ancestral Recombination Graph, the central data structure.
Mirrors ARG.cpp / ARG.hpp.
An ARG is encoded as a sorted map of Recombination records keyed by
genomic position. A marginal tree at position x is obtained by
replaying all records from position 0 up to (and including) x.
The two main MCMC operations are:
remove(cut_point) — extract a single lineage from the ARG, leaving
the remaining haplotypes connected.
add(joining, added) — thread the lineage back in at new positions.
"""
from __future__ import annotations
import math
import sys
from typing import Dict, List, Optional, Set, Tuple
import numpy as np
from sortedcontainers import SortedDict
from .branch import Branch
from .node import Node
from .recombination import Recombination
from .tree import Tree
_INT_MAX = sys.maxsize
def _choose_time(lb: float, ub: float) -> float:
"""Midpoint in log space; falls back to linear for small intervals.
Mirrors RSP_smc::choose_time(double lb, double ub).
"""
if ub - lb < 0.01:
return 0.5 * (lb + ub)
mt = math.log(0.5 * (math.exp(lb) + math.exp(ub)))
return max(lb, min(ub, mt))
[docs]
class ARG:
"""Ancestral Recombination Graph.
Attributes
----------
Ne: Effective population size (diploid).
sequence_length: Length of the genomic region (base pairs).
root: Sentinel root node (time=inf, index=-1).
sample_nodes: Set of leaf (sample) nodes.
node_set: All non-root nodes.
recombinations: SortedDict[pos → Recombination]. Always has
sentinels at 0 and INT_MAX.
mutation_sites: Sorted set of positions carrying derived alleles.
mutation_branches: pos → set of Branches carrying the mutation.
coordinates: Grid of genomic positions (HMM bins).
rhos: Scaled recombination rate per bin.
thetas: Scaled mutation rate per bin.
MCMC working variables (reset by clear_remove_info):
removed_branches: pos → Branch that was removed.
joining_branches: pos → Branch that "jumped over" the cut.
cut_tree: Marginal tree at the cut position.
cut_pos, cut_node: Position and sentinel node for the cut.
start, end: Genomic extent of the current MCMC window.
"""
def __init__(self, Ne: float = 1.0, sequence_length: float = 1.0) -> None:
self.Ne = Ne
self.sequence_length = sequence_length
# Root sentinel: time=inf, index=-1
self.root = Node(time=math.inf, index=-1)
self.sample_nodes: Set[Node] = set()
self.node_set: Set[Node] = set()
# Recombination records — always have sentinels at 0 and INT_MAX
self.recombinations: SortedDict = SortedDict()
r0 = Recombination()
r0.set_pos(0.0)
self.recombinations[0.0] = r0
r_end = Recombination()
r_end.set_pos(float(_INT_MAX))
self.recombinations[float(_INT_MAX)] = r_end
# Mutation data
self.mutation_sites: SortedDict = SortedDict() # pos → True (sorted set)
self.mutation_sites[float(_INT_MAX)] = True
self.mutation_branches: Dict[float, Set[Branch]] = {
float(_INT_MAX): set()
}
# HMM grid
self.coordinates: List[float] = []
self.rhos: List[float] = []
self.thetas: List[float] = []
self.bin_size: float = 1.0
self.bin_num: int = 0
# MCMC working state
self.removed_branches: SortedDict = SortedDict() # pos → Branch
self.joining_branches: SortedDict = SortedDict() # pos → Branch
self.cut_tree: Tree = Tree()
self.start_tree: Tree = Tree()
self.end_tree: Tree = Tree()
self.cut_pos: float = 0.0
self.cut_time: float = 0.0
self.cut_node: Optional[Node] = None
self.start: float = 0.0
self.end: float = sequence_length
# RNG (set by Sampler)
self.rng: Optional[np.random.Generator] = None
# ------------------------------------------------------------------
# Random helper
# ------------------------------------------------------------------
def _random(self) -> float:
if self.rng is not None:
return float(self.rng.uniform())
return float(np.random.uniform())
# ------------------------------------------------------------------
# Initialisation helpers
# ------------------------------------------------------------------
[docs]
def add_sample(self, n: Node) -> None:
"""Register sample node *n* and update mutation_sites."""
self.sample_nodes.add(n)
for pos in n.mutation_sites:
if pos != -1:
self.mutation_sites[pos] = True
self.removed_branches = SortedDict()
self.removed_branches[0.0] = Branch(n, self.root)
self.removed_branches[self.sequence_length] = Branch()
self.start_tree = self.get_tree_at(0.0)
self.cut_pos = 0.0
self.start = 0.0
self.end = self.sequence_length
[docs]
def build_singleton_arg(self, n: Node) -> None:
"""Build an ARG containing a single sample node *n*."""
self.add_sample(n)
branch = Branch(n, self.root)
r0 = Recombination(set(), {branch})
r0.set_pos(0.0)
self.recombinations[0.0] = r0
for pos in list(self.mutation_sites.keys()):
self.mutation_branches[pos] = {branch}
[docs]
def add_node(self, n: Node) -> None:
if n is not self.root and n is not None:
self.node_set.add(n)
# ------------------------------------------------------------------
# Grid / rate construction
# ------------------------------------------------------------------
[docs]
def discretize(self, bin_size: float) -> None:
"""Build coordinate grid, placing breakpoints at recombinations.
Mirrors ARG::discretize.
"""
self.bin_size = bin_size
self.coordinates = []
recomb_keys = list(self.recombinations.keys())
recomb_idx = 1 # skip the sentinel at 0
curr_pos = 0.0
while curr_pos < self.sequence_length:
self.coordinates.append(curr_pos)
next_recomb = recomb_keys[recomb_idx] if recomb_idx < len(recomb_keys) else _INT_MAX
if next_recomb < curr_pos + bin_size:
curr_pos = next_recomb
recomb_idx += 1
else:
curr_pos = min(curr_pos + bin_size, self.sequence_length)
self.coordinates.append(self.sequence_length)
self.bin_num = len(self.coordinates) - 1
[docs]
def get_index(self, x: float) -> int:
"""Return index i such that coordinates[i] <= x < coordinates[i+1]."""
import bisect
idx = bisect.bisect_right(self.coordinates, x) - 1
return max(0, idx)
[docs]
def compute_rhos_thetas(self, r: float, m: float) -> None:
"""Compute per-bin scaled recombination/mutation rates.
Mirrors ARG::compute_rhos_thetas(double r, double m).
"""
n = len(self.coordinates) - 1
self.rhos = []
self.thetas = []
for i in range(n):
span = self.coordinates[i + 1] - self.coordinates[i]
self.rhos.append(r * span)
self.thetas.append(m * span)
# ------------------------------------------------------------------
# Tree access
# ------------------------------------------------------------------
[docs]
def get_tree_at(self, x: float) -> Tree:
"""Return the marginal tree at position *x*.
Replays all Recombination records with pos <= x.
"""
tree = Tree()
for pos, r in self.recombinations.items():
if pos <= x:
tree.forward_update(r)
else:
break
return tree
[docs]
def get_query_node_at(self, x: float) -> Optional[Node]:
"""Return the query node (lower_node of the removed branch at x)."""
idx = self.removed_branches.bisect_right(x) - 1
if idx < 0:
return None
key = self.removed_branches.keys()[idx]
b = self.removed_branches[key]
return b.lower_node if b else None
# ------------------------------------------------------------------
# MCMC: remove
# ------------------------------------------------------------------
[docs]
def remove(self, cut_point: Tuple[float, Branch, float]) -> None:
"""Remove a lineage from the ARG.
cut_point = (pos, center_branch, cut_time).
After this call:
- self.removed_branches maps positions to the removed branch.
- self.joining_branches maps positions to the "joining" branch.
- self.start / self.end delimit the affected genomic region.
- self.start_tree is the marginal tree at start (minus the cut).
"""
pos, center_branch, t = cut_point
self.cut_time = t
self.cut_node = Node(time=t)
self.cut_node.index = -2
self.removed_branches = SortedDict()
self.joining_branches = SortedDict()
forward_tree = self.cut_tree.copy()
backward_tree = self.cut_tree.copy()
# ---- forward pass: trace the removed branch to the right ----
f_it_idx = self.recombinations.bisect_right(pos)
keys = list(self.recombinations.keys())
prev_joining = Branch()
prev_removed = center_branch
next_removed = center_branch
while next_removed:
if f_it_idx >= len(keys):
break
r_pos = keys[f_it_idx]
r = self.recombinations[r_pos]
prev_joining = forward_tree.find_joining_branch(prev_removed)
forward_tree.forward_update(r)
next_removed = r.trace_forward(t, prev_removed)
if next_removed and next_removed.upper_node is self.root:
next_removed = Branch()
next_joining = forward_tree.find_joining_branch(next_removed)
r.remove(prev_removed, next_removed, prev_joining, next_joining, self.cut_node)
store_pos = min(r_pos, self.sequence_length)
self.removed_branches[store_pos] = next_removed
self.joining_branches[store_pos] = next_joining
f_it_idx += 1
prev_removed = next_removed
# ---- backward pass: trace the removed branch to the left ----
b_it_idx = self.recombinations.bisect_right(pos) - 1
next_removed = center_branch
prev_removed = center_branch
while prev_removed:
if b_it_idx < 0:
break
r_pos = keys[b_it_idx]
r = self.recombinations[r_pos]
self.removed_branches[r_pos] = prev_removed
next_joining = backward_tree.find_joining_branch(next_removed)
self.joining_branches[r_pos] = next_joining
backward_tree.backward_update(r)
prev_removed = r.trace_backward(t, next_removed)
if prev_removed and prev_removed.upper_node is self.root:
prev_removed = Branch()
if not prev_removed:
backward_tree.forward_update(r)
prev_joining = backward_tree.find_joining_branch(prev_removed)
r.remove(prev_removed, next_removed, prev_joining, next_joining, self.cut_node)
b_it_idx -= 1
next_removed = prev_removed
# ---- update start / end ----
if self.removed_branches:
self.start = self.removed_branches.keys()[0]
self.end = self.removed_branches.keys()[-1]
self._remove_empty_recombinations()
self._remap_mutations()
self.cut_tree.remove(center_branch, self.cut_node)
backward_tree.remove(
self.removed_branches[self.removed_branches.keys()[0]],
self.cut_node,
)
self.start_tree = backward_tree
self.end_tree = forward_tree
# ------------------------------------------------------------------
# MCMC: add
# ------------------------------------------------------------------
[docs]
def add(
self,
new_joining_branches: SortedDict,
added_branches: SortedDict,
) -> None:
"""Thread the removed lineage back in at new positions.
Mirrors ARG::add.
"""
join_keys = list(new_joining_branches.keys())
add_keys = list(added_branches.keys())
join_idx = 0
add_idx = 0
r_keys = list(self.recombinations.keys())
r_start_idx = self.recombinations.bisect_left(self.start)
prev_joining = Branch()
next_joining = Branch()
prev_added = Branch()
next_added = Branch()
r_idx = r_start_idx
while add_idx < len(add_keys):
add_pos = add_keys[add_idx]
if add_pos >= self.sequence_length:
break
# Advance join pointer
if join_idx < len(join_keys) and join_keys[join_idx] == add_pos:
next_joining = new_joining_branches[join_keys[join_idx]]
join_idx += 1
next_added = added_branches[add_pos]
# Advance r_idx past any recombinations before add_pos
while r_idx < len(r_keys) and r_keys[r_idx] < add_pos:
r_idx += 1
# Is there an existing recombination at add_pos?
if r_idx < len(r_keys) and r_keys[r_idx] == add_pos:
r = self.recombinations[r_keys[r_idx]]
r_idx += 1
r.add(prev_added, next_added, prev_joining, next_joining, self.cut_node)
else:
self._new_recombination(
add_pos,
prev_added, prev_joining,
next_added, next_joining,
)
prev_joining = next_joining
prev_added = next_added
add_idx += 1
self._remove_empty_recombinations()
self._impute(new_joining_branches, added_branches)
# Update start_tree
first_join = new_joining_branches[new_joining_branches.keys()[0]]
first_added = added_branches[added_branches.keys()[0]]
self.start_tree.add(first_added, first_join, self.cut_node)
# ------------------------------------------------------------------
# Recombination time sampling
# ------------------------------------------------------------------
[docs]
def approx_sample_recombinations(self) -> None:
"""Sample start_times and finalize derived fields for all recombinations.
Mirrors ARG::approx_sample_recombinations / RSP_smc::approx_sample_recombination.
Sets source_branch, start_time, then calls _find_target_branch and
_find_recomb_info so that merging_branch (used by BSP) is valid.
"""
for pos, r in self.recombinations.items():
if pos == 0 or pos >= self.sequence_length:
continue
if r.start_time > 0:
continue
if not r.deleted_branches:
continue
if r.deleted_node is None or r.inserted_node is None:
continue
# Find source candidates: deleted branches ending at deleted_node
# whose induced recombined branch is in inserted_branches.
source_candidates = []
for b in r.deleted_branches:
if (b.upper_node is r.deleted_node
and b.lower_node.time < r.inserted_node.time):
candidate = Branch(b.lower_node, r.inserted_node)
if r.create(candidate):
source_candidates.append(b)
if len(source_candidates) == 1:
r.source_branch = source_candidates[0]
lb = max(self.cut_time, r.source_branch.lower_node.time)
ub = min(r.source_branch.upper_node.time, r.inserted_node.time)
r.start_time = _choose_time(lb, ub)
elif len(source_candidates) == 2:
lb1 = max(self.cut_time, source_candidates[0].lower_node.time)
lb2 = max(self.cut_time, source_candidates[1].lower_node.time)
ub1 = min(source_candidates[0].upper_node.time, r.inserted_node.time)
ub2 = min(source_candidates[1].upper_node.time, r.inserted_node.time)
q = (ub1 - lb1) / (ub1 + ub2 - lb1 - lb2) if (ub1 + ub2 - lb1 - lb2) > 0 else 0.5
if 0.5 <= q:
r.source_branch = source_candidates[0]
r.start_time = _choose_time(lb1, ub1)
else:
r.source_branch = source_candidates[1]
r.start_time = _choose_time(lb2, ub2)
else:
continue # skip if no valid source found
# Handle degenerate case: deleted and inserted nodes at same time
if r.deleted_node.time == r.inserted_node.time:
r.inserted_node.time = math.nextafter(r.inserted_node.time, math.inf)
r._find_target_branch()
r._find_recomb_info()
ub = min(r.deleted_node.time, r.inserted_node.time)
if r.start_time >= ub:
r.start_time = math.nextafter(ub, -math.inf)
# ------------------------------------------------------------------
# Cut-point sampling (used by Threader / Sampler)
# ------------------------------------------------------------------
[docs]
def sample_internal_cut(self) -> Tuple[float, Branch, float]:
"""Sample a random (pos, branch, cut_time) for the next MCMC step.
Mirrors ARG::sample_internal_cut.
"""
if self.end >= self.sequence_length - 0.1:
self.cut_pos = 0.0
self.cut_tree = self.get_tree_at(0.0)
else:
self.cut_tree = self.end_tree.copy()
self.cut_pos = self.end
b, t = self._tree_sample_cut_point(self.cut_tree)
# Avoid sampling exactly at node boundaries
max_tries = 20
for _ in range(max_tries):
if t != b.lower_node.time and t != b.upper_node.time:
break
b, t = self._tree_sample_cut_point(self.cut_tree)
return (self.cut_pos, b, t)
def _tree_sample_cut_point(self, tree: Tree) -> Tuple[Branch, float]:
"""Sample a (branch, time) pair uniformly over the tree.
Mirrors Tree::sample_cut_point.
"""
# Find root time (max non-inf parent time)
max_time = 0.0
for child, parent in tree.parents.items():
if not math.isinf(parent.time) and parent.time > max_time:
max_time = parent.time
cut_time = self._random() * max_time
candidates = [
Branch(child, parent)
for child, parent in tree.parents.items()
if not math.isinf(parent.time)
and parent.time > cut_time and child.time <= cut_time
]
if not candidates:
# Fallback: pick any branch
candidates = [Branch(c, p) for c, p in tree.parents.items()
if not math.isinf(p.time)]
if not candidates:
raise RuntimeError("No valid branches in tree for cut point sampling")
idx = int(math.floor(len(candidates) * self._random()))
idx = min(len(candidates) - 1, idx)
return candidates[idx], cut_time
# ------------------------------------------------------------------
# Check-points (for BSP / TSP sanity checks)
# ------------------------------------------------------------------
[docs]
def get_check_points(self) -> Set[float]:
"""Return the set of recombination positions that need sanity checks."""
if not self.removed_branches:
return set()
start_pos = self.removed_branches.keys()[0]
end_pos = self.removed_branches.keys()[-1]
deleted_nodes: Dict[Optional[Node], float] = {}
node_spans = []
for pos in self.recombinations.irange(start_pos, end_pos):
r = self.recombinations[pos]
if r.deleted_node is not None:
deleted_nodes[r.deleted_node] = pos
ins = r.inserted_node
if ins is not None and ins in deleted_nodes and ins is not self.root:
node_spans.append((ins, deleted_nodes[ins], pos))
del deleted_nodes[ins]
check_points: Set[float] = set()
for n, x, y in node_spans:
if not self._check_disjoint_nodes(x, y):
check_points.add(y)
return check_points
def _check_disjoint_nodes(self, x: float, y: float) -> bool:
r_start = self.recombinations[x]
if r_start.deleted_node is None:
return False
t = r_start.deleted_node.time
b = r_start.merging_branch
for pos in self.recombinations.irange(x, y - 1e-15):
r = self.recombinations[pos]
b = r.trace_forward(t, b)
if not b:
return False
r_end = self.recombinations.get(y)
if r_end is None:
return False
return b == r_end.target_branch
# ------------------------------------------------------------------
# Utility
# ------------------------------------------------------------------
[docs]
def clear_remove_info(self) -> None:
"""Reset MCMC working state."""
self.removed_branches = SortedDict()
self.joining_branches = SortedDict()
self.cut_node = None
[docs]
def count_flipping(self) -> int:
count = 0
for pos, branches in self.mutation_branches.items():
if len(branches) > 1:
# Check if root branch is in the set
for b in branches:
if b.upper_node is self.root:
count += 1
break
return count
[docs]
def count_incompatibility(self) -> int:
count = 0
for pos, branches in self.mutation_branches.items():
if len(branches) > 1:
has_root = any(b.upper_node is self.root for b in branches)
if has_root:
if len(branches) > 2:
count += 1
else:
count += 1
return count
[docs]
def num_unmapped(self) -> int:
return self.count_incompatibility()
[docs]
def get_arg_length(self) -> float:
"""Total ARG length (sum of branch_length * genomic_span)."""
tree = self.get_tree_at(0.0)
prev_pos = 0.0
total = 0.0
r_iter = iter(self.recombinations.items())
# skip sentinel at 0
next(r_iter)
tree_len = tree.length()
for r_pos, r in r_iter:
next_pos = min(r_pos, self.sequence_length)
total += tree_len * (next_pos - prev_pos)
if r_pos >= self.sequence_length:
break
tree.forward_update(r)
tree_len = tree.length()
prev_pos = next_pos
return total
# ------------------------------------------------------------------
# Private helpers
# ------------------------------------------------------------------
def _remove_empty_recombinations(self) -> None:
"""Remove recombinations in [start, end] with no branch changes."""
to_delete = []
for pos in self.recombinations.irange(self.start, self.end):
r = self.recombinations[pos]
if (not r.deleted_branches and not r.inserted_branches
and pos < self.sequence_length):
to_delete.append(pos)
for pos in to_delete:
del self.recombinations[pos]
def _new_recombination(
self,
pos: float,
prev_added: Branch,
prev_joining: Branch,
next_added: Branch,
next_joining: Branch,
) -> None:
"""Create a new Recombination record at *pos*.
Mirrors ARG::new_recombination.
"""
deleted = set()
inserted = set()
# Branches to delete (exist before pos, not after)
def _safe_add(s, b):
if b.lower_node is not None and b.upper_node is not None:
s.add(b)
_safe_add(deleted, prev_added)
_safe_add(deleted, Branch(prev_joining.lower_node, prev_added.upper_node))
_safe_add(deleted, Branch(prev_added.upper_node, prev_joining.upper_node))
_safe_add(deleted, next_joining)
_safe_add(inserted, next_added)
_safe_add(inserted, Branch(next_joining.lower_node, next_added.upper_node))
_safe_add(inserted, Branch(next_added.upper_node, next_joining.upper_node))
_safe_add(inserted, prev_joining)
r = Recombination(deleted, inserted)
r.set_pos(pos)
self.recombinations[pos] = r
def _remap_mutations(self) -> None:
"""Update mutation_branches after removing a lineage.
Mirrors ARG::remap_mutations.
"""
if not self.joining_branches or not self.removed_branches:
return
x = self.joining_branches.keys()[0]
y = self.joining_branches.keys()[-1]
join_idx = 0
remove_idx = 0
join_keys = list(self.joining_branches.keys())
remove_keys = list(self.removed_branches.keys())
joining_branch = Branch()
removed_branch = Branch()
mut_keys = list(k for k in self.mutation_branches.keys() if x <= k < y)
for m in mut_keys:
# Advance join pointer
while join_idx < len(join_keys) - 1 and join_keys[join_idx + 1] <= m:
join_idx += 1
while remove_idx < len(remove_keys) - 1 and remove_keys[remove_idx + 1] <= m:
remove_idx += 1
joining_branch = self.joining_branches[join_keys[join_idx]]
removed_branch = self.removed_branches[remove_keys[remove_idx]]
if not joining_branch or not removed_branch:
continue
joining_node = removed_branch.upper_node
if joining_node is None:
continue
lower_branch = Branch(joining_branch.lower_node, joining_node)
upper_branch = Branch(joining_node, joining_branch.upper_node)
branches = self.mutation_branches.get(m, set())
branches.discard(removed_branch)
branches.discard(lower_branch)
branches.discard(upper_branch)
sl = lower_branch.lower_node.get_state(m) if lower_branch.lower_node else 0
su = upper_branch.upper_node.get_state(m) if upper_branch.upper_node else 0
if sl != su:
branches.add(joining_branch)
self.mutation_branches[m] = branches
def _impute(
self,
new_joining_branches: SortedDict,
added_branches: SortedDict,
) -> None:
"""Impute mutation states for newly threaded node.
Mirrors ARG::impute.
"""
add_keys = list(added_branches.keys())
join_keys = list(new_joining_branches.keys())
join_idx = 0
joining_branch = Branch()
added_branch = Branch()
for i, add_pos in enumerate(add_keys[:-1]):
next_add_pos = add_keys[i + 1]
added_branch = added_branches[add_pos]
if join_idx < len(join_keys) and join_keys[join_idx] == add_pos:
joining_branch = new_joining_branches[join_keys[join_idx]]
join_idx += 1
for m in self.mutation_sites.irange(add_pos, next_add_pos - 1e-15):
if m == _INT_MAX:
break
self._map_mutation_branch(m, joining_branch, added_branch)
def _map_mutation_branch(
self,
x: float,
joining_branch: Branch,
added_branch: Branch,
) -> None:
"""Update mutation_branches[x] after threading added_branch.
Mirrors ARG::map_mutation(double, Branch, Branch).
"""
if joining_branch.is_null() or added_branch.is_null():
return
sl = joining_branch.lower_node.get_state(x)
su = joining_branch.upper_node.get_state(x)
s0 = added_branch.lower_node.get_state(x)
sm = 1 if (sl + su + s0 > 1) else 0
added_branch.upper_node.write_state(x, sm)
branches = self.mutation_branches.get(x, set())
if sl != su:
branches.discard(joining_branch)
if sm != sl:
branches.add(Branch(joining_branch.lower_node, added_branch.upper_node))
if sm != su:
branches.add(Branch(added_branch.upper_node, joining_branch.upper_node))
if sm != s0:
branches.add(added_branch)
self.mutation_branches[x] = branches
# ------------------------------------------------------------------
# Repr
# ------------------------------------------------------------------
def __repr__(self) -> str:
n_recombs = len(self.recombinations) - 2 # exclude sentinels
return (
f"ARG(samples={len(self.sample_nodes)}, "
f"recombs={n_recombs}, "
f"length={self.sequence_length})"
)