"""
Threader — MCMC move: remove a lineage and re-thread it.
Mirrors Threader_smc.cpp / Threader_smc.hpp.
Two public entry points:
thread(arg, node) — add a new leaf node (initial threading)
internal_rethread(arg, cut) — MCMC move with Metropolis acceptance
"""
from __future__ import annotations
import math
from typing import Dict, Optional, Set, Tuple, TYPE_CHECKING
import numpy as np
from sortedcontainers import SortedDict
from ..data.branch import Branch
from ..data.node import Node
from ..hmm.bsp import BSP
from ..hmm.tsp import TSP
from ..hmm.emission import BinaryEmission, PolarEmission
if TYPE_CHECKING:
from ..data.arg import ARG
from ..data.recombination import Recombination
[docs]
class Threader:
"""BSP + TSP threader for adding/rethreading a lineage in an ARG.
Parameters
----------
cutoff : float
BSP state-space pruning threshold (bsp_c in C++).
gap : float
TSP time grid quantile gap (tsp_q in C++).
"""
def __init__(self, cutoff: float = 0.0, gap: float = 0.02) -> None:
self.cutoff = cutoff
self.gap = gap
self.bsp: BSP = BSP()
self.tsp: TSP = TSP()
# Emission models: polar (BSP) and binary (TSP)
self.pe: PolarEmission = PolarEmission()
self.be: BinaryEmission = BinaryEmission()
self.cut_time: float = 0.0
self.start: float = 0.0
self.end: float = 0.0
self.start_index: int = 0
self.end_index: int = 0
# Results of a threading run
self.new_joining_branches: SortedDict = SortedDict()
self.added_branches: SortedDict = SortedDict()
self._rng: np.random.Generator = np.random.default_rng()
[docs]
def set_rng(self, rng: np.random.Generator) -> None:
self._rng = rng
self.bsp.set_rng(rng)
self.tsp.set_rng(rng)
def _random(self) -> float:
return float(self._rng.uniform())
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
[docs]
def thread(self, arg: "ARG", node: Node) -> None:
"""Add *node* as a new leaf and thread its lineage through *arg*.
Mirrors Threader_smc::thread.
"""
self.cut_time = 0.0
arg.cut_time = 0.0
arg.add_sample(node)
self._get_boundary(arg)
self._run_bsp(arg)
self._sample_joining_branches(arg)
self._run_tsp(arg)
self._sample_joining_points(arg)
arg.add(self.new_joining_branches, self.added_branches)
arg.approx_sample_recombinations()
arg.clear_remove_info()
[docs]
def internal_rethread(
self,
arg: "ARG",
cut_point: Tuple[float, Branch, float],
) -> None:
"""MCMC move: remove and re-thread a lineage segment.
Proposes a new threading, accepts with Metropolis ratio.
Mirrors Threader_smc::internal_rethread.
"""
self.cut_time = cut_point[2]
arg.remove(cut_point)
self._get_boundary(arg)
self._set_check_points(arg)
self._run_bsp(arg)
self._sample_joining_branches(arg)
self._run_tsp(arg)
self._sample_joining_points(arg)
ar = self._acceptance_ratio(arg)
if self._random() < ar:
arg.add(self.new_joining_branches, self.added_branches)
else:
arg.add(arg.joining_branches, arg.removed_branches)
arg.approx_sample_recombinations()
arg.clear_remove_info()
# ------------------------------------------------------------------
# Private: boundary and check points
# ------------------------------------------------------------------
def _get_boundary(self, arg: "ARG") -> None:
self.start = arg.start
self.end = arg.end
self.start_index = arg.get_index(self.start)
self.end_index = arg.get_index(self.end)
def _set_check_points(self, arg: "ARG") -> None:
check_points = arg.get_check_points()
self.bsp.set_check_points(check_points)
self.tsp.set_check_points(check_points)
# ------------------------------------------------------------------
# Private: BSP forward pass
# ------------------------------------------------------------------
def _run_bsp(self, arg: "ARG") -> None:
"""Run the BSP forward pass over *arg*.
Mirrors Threader_smc::run_BSP.
"""
self.bsp.reserve_memory(self.end_index - self.start_index)
self.bsp.set_cutoff(self.cutoff)
self.bsp.set_emission(self.pe)
self.bsp.start(arg.start_tree, self.cut_time)
recomb_it = arg.recombinations.irange(
None, None, inclusive=(False, True)
)
# Advance past start
recomb_keys = list(arg.recombinations.irange(
self.start, None, inclusive=(False, True)
))
recomb_idx = 0
mut_keys = sorted(m for m in arg.mutation_sites if m >= self.start)
mut_idx = 0
removed_items = list(arg.removed_branches.items())
query_idx = 0
query_node: Optional[Node] = None
for i in range(self.start_index, self.end_index):
pos = arg.coordinates[i]
# Advance query node
if query_idx < len(removed_items) and pos == removed_items[query_idx][0]:
query_node = removed_items[query_idx][1].lower_node
query_idx += 1
# Recombination or forward step
if recomb_idx < len(recomb_keys) and pos == recomb_keys[recomb_idx]:
r = arg.recombinations[recomb_keys[recomb_idx]]
recomb_idx += 1
self.bsp.transfer(r)
elif pos != self.start:
self.bsp.forward(arg.rhos[i - 1])
# Collect mutations in [pos, next_pos)
next_pos = arg.coordinates[i + 1]
mut_set: Set[float] = set()
while mut_idx < len(mut_keys) and mut_keys[mut_idx] < next_pos:
mut_set.add(mut_keys[mut_idx])
mut_idx += 1
if mut_set:
self.bsp.mut_emit(
arg.thetas[i],
next_pos - pos,
mut_set,
query_node,
)
else:
self.bsp.null_emit(arg.thetas[i], query_node)
# Sanity check at end boundary
if self.end in self.bsp.check_points:
r = arg.recombinations[self.end]
self.bsp.sanity_check(r)
# ------------------------------------------------------------------
# Private: TSP forward pass
# ------------------------------------------------------------------
def _run_tsp(self, arg: "ARG") -> None:
"""Run the TSP forward pass over the sampled joining branches.
Mirrors Threader_smc::run_TSP.
"""
self.tsp.reserve_memory(self.end_index - self.start_index)
self.tsp.set_gap(self.gap)
self.tsp.set_emission(self.be)
start_branch = self.new_joining_branches.peekitem(0)[1] # first value
self.tsp.start(start_branch, self.cut_time)
recomb_keys = list(arg.recombinations.irange(
self.start, None, inclusive=(False, True)
))
recomb_idx = 0
join_keys = list(self.new_joining_branches.irange(
self.start, None, inclusive=(False, True)
))
join_idx = 0
mut_keys = sorted(m for m in arg.mutation_sites if m >= self.start)
mut_idx = 0
removed_items = list(arg.removed_branches.items())
query_idx = 0
query_node: Optional[Node] = None
prev_branch = start_branch
next_branch = start_branch
for i in range(self.start_index, self.end_index):
pos = arg.coordinates[i]
# Advance query node
if query_idx < len(removed_items) and pos == removed_items[query_idx][0]:
query_node = removed_items[query_idx][1].lower_node
query_idx += 1
# Advance joining branch
if join_idx < len(join_keys) and pos == join_keys[join_idx]:
next_branch = self.new_joining_branches[join_keys[join_idx]]
join_idx += 1
# Transfer / recombine / forward
if recomb_idx < len(recomb_keys) and pos == recomb_keys[recomb_idx]:
r = arg.recombinations[recomb_keys[recomb_idx]]
recomb_idx += 1
self.tsp.transfer(r, prev_branch, next_branch)
prev_branch = next_branch
elif prev_branch is not next_branch:
self.tsp.recombine(prev_branch, next_branch)
prev_branch = next_branch
elif pos != self.start:
self.tsp.forward(arg.rhos[i])
# Collect mutations in [pos, next_pos)
next_pos = arg.coordinates[i + 1]
mut_set: Set[float] = set()
while mut_idx < len(mut_keys) and mut_keys[mut_idx] < next_pos:
mut_set.add(mut_keys[mut_idx])
mut_idx += 1
if mut_set:
self.tsp.mut_emit(
arg.thetas[i],
next_pos - pos,
mut_set,
query_node,
)
else:
self.tsp.null_emit(arg.thetas[i], query_node)
# Sanity check at end boundary
if self.end in self.tsp.check_points:
r = arg.recombinations[self.end]
self.tsp._sanity_check(r)
# ------------------------------------------------------------------
# Private: sampling
# ------------------------------------------------------------------
def _sample_joining_branches(self, arg: "ARG") -> None:
"""Sample joining branches from BSP traceback.
Mirrors Threader_smc::sample_joining_branches.
"""
self.new_joining_branches = self.bsp.sample_joining_branches(
self.start_index, arg.coordinates
)
def _sample_joining_points(self, arg: "ARG") -> None:
"""Sample joining nodes from TSP traceback and build added_branches.
Mirrors Threader_smc::sample_joining_points.
"""
added_nodes: Dict[float, Optional[Node]] = self.tsp.sample_joining_nodes(
self.start_index, arg.coordinates
)
self.added_branches = SortedDict()
for x, added_node in added_nodes.items():
if added_node is None:
self.added_branches[x] = Branch() # null sentinel at sequence_length boundary
else:
query_node = arg.get_query_node_at(x)
self.added_branches[x] = Branch(query_node, added_node)
# ------------------------------------------------------------------
# Private: acceptance ratio
# ------------------------------------------------------------------
def _acceptance_ratio(self, arg: "ARG") -> float:
"""Compute Metropolis acceptance ratio.
Mirrors Threader_smc::acceptance_ratio.
"""
# Height of the cut tree: max time among CHILD nodes (keys of parents dict).
# C++: cut_tree.parents.rbegin()->first->time (rbegin = max-time key).
# Root is never a key (it has no parent), so we never get time=inf here.
cut_height = max(
(child.time for child in arg.cut_tree.parents.keys()),
default=0.0,
)
old_height = cut_height
new_height = cut_height
# Find old joining branch at cut_pos
old_join_keys = [k for k in arg.joining_branches.keys() if k <= arg.cut_pos]
old_join_branch = arg.joining_branches[max(old_join_keys)] if old_join_keys else None
new_join_keys = [k for k in self.new_joining_branches.keys() if k <= arg.cut_pos]
new_join_branch = self.new_joining_branches[max(new_join_keys)] if new_join_keys else None
old_add_keys = [k for k in arg.removed_branches.keys() if k <= arg.cut_pos]
old_add_branch = arg.removed_branches[max(old_add_keys)] if old_add_keys else None
new_add_keys = [k for k in self.added_branches.keys() if k <= arg.cut_pos]
new_add_branch = self.added_branches[max(new_add_keys)] if new_add_keys else None
if old_join_branch is not None and old_join_branch.upper_node is arg.root:
old_height = old_add_branch.upper_node.time if old_add_branch else cut_height
if new_join_branch is not None and new_join_branch.upper_node is arg.root:
new_height = new_add_branch.upper_node.time if new_add_branch else cut_height
if new_height <= 0:
return 1.0
return old_height / new_height