# Source code for pysinger.hmm.tsp

"""
TSP — Time Sequence Propagator, the forward HMM for coalescence times.

Mirrors TSP_smc.cpp / TSP_smc.hpp.

The TSP operates on a single Branch and samples a representative
coalescence time in each genomic bin.  It uses a PSMC-style transition
kernel (psmc_prob) rather than the coalescent-CDF-based kernel of BSP.

Key public API
--------------
start(branch, t)                     — initialise at the left boundary
forward(rho)                         — advance one bin
transfer(r, prev_branch, next_branch) — apply topology change
recombine(prev_branch, next_branch)  — re-sample at a recombination
null_emit(theta, query_node)         — apply no-mutation emission
mut_emit(theta, bin_size, mut_set, query_node) — apply mutation emission
sample_joining_nodes(start_index, coordinates) — traceback → Dict[pos, Node]
"""
from __future__ import annotations

import math
from typing import Dict, List, Optional, Set, Tuple, TYPE_CHECKING

import numpy as np
from sortedcontainers import SortedDict

from ..data.branch import Branch
from ..data.interval import Interval
from ..data.node import Node

if TYPE_CHECKING:
    from ..data.recombination import Recombination
    from .emission import Emission


# Small positive constant that keeps probabilities away from exactly zero:
# TSP uses it as a floor in forward() and in trace-back probabilities of
# non-degenerate intervals, and adds it to every accumulated term in recombine().
_EPSILON = 1e-20
# Next index handed out to nodes created by TSP._sample_joining_node.  It is
# module-level, so all TSP instances in the process share one sequence.
_counter: int = 0  # mirrors TSP_smc::counter (static)


def _new_node(t: float) -> Node:
    """Return a fresh Node located at time *t*.

    Counterpart of the C++ ``new_node`` helper.  No index is assigned here;
    the caller is responsible for setting one if needed.
    """
    return Node(time=t)


class TSP:
    """Time Sequence Propagator — forward HMM for the time dimension.

    State space: a list of Interval objects (finely gridded over a Branch).
    Forward probabilities: forward_probs[step][interval_idx].

    Mirrors TSP_smc in the C++ code.
    """

    def __init__(self) -> None:
        self.cut_time: float = 0.0
        self.lower_bound: float = 0.0
        self.gap: float = 0.02  # default quantile gap for grid generation
        self.eh: Optional["Emission"] = None
        self.check_points: Set[float] = set()
        self.curr_index: int = 0
        self.curr_branch: Optional[Branch] = None
        self.curr_intervals: List[Interval] = []
        self.forward_probs: List[List[float]] = []
        # step index -> List[Interval] valid from that step onward.  Kept in a
        # SortedDict so the floor lookups made during traceback are O(log n).
        self.state_spaces: SortedDict = SortedDict()
        self.source_interval: Dict[int, Interval] = {}  # id(new_iv) -> old Interval
        self.rhos: List[float] = []
        # per-step work arrays (resized in _set_dimensions)
        self.dim: int = 0
        self.diagonals: List[float] = []
        self.lower_diagonals: List[float] = []
        self.upper_diagonals: List[float] = []
        self.lower_sums: List[float] = []
        self.upper_sums: List[float] = []
        self.null_emit_probs: List[float] = []
        self.mut_emit_probs: List[float] = []
        self.factors: List[float] = []
        self.trace_back_probs: List[float] = []
        # Inputs that produced trace_back_probs:
        # (rho, interval, intervals, lower_bound, cut_time); None = no cache.
        self._trace_back_key: Optional[tuple] = None
        self.emissions: List[float] = [0.0, 0.0, 0.0, 0.0]
        self.sample_index: int = 0
        self.prev_rho: float = -1.0
        self.prev_theta: float = -1.0
        self.prev_node: Optional[Node] = None
        # temp buffer for building forward_probs at the current step
        self._temp: List[float] = []
        self._rng: np.random.Generator = np.random.default_rng()

    def set_rng(self, rng: np.random.Generator) -> None:
        """Use *rng* for every random draw made by this propagator."""
        self._rng = rng

    def _random(self) -> float:
        """Return one uniform [0, 1) draw from the configured generator."""
        return float(self._rng.uniform())

    # ------------------------------------------------------------------
    # Configuration
    # ------------------------------------------------------------------

    def set_gap(self, q: float) -> None:
        """Set the Exp(1)-quantile spacing used when building interval grids."""
        self.gap = q

    def set_emission(self, e: "Emission") -> None:
        """Set the emission model used by null_emit / mut_emit."""
        self.eh = e

    def set_check_points(self, p: Set[float]) -> None:
        """Store the set of check-point positions (kept for API parity)."""
        self.check_points = p

    def reserve_memory(self, length: int) -> None:
        """Reset per-step storage; *length* is accepted for C++ API parity."""
        self.forward_probs = []
        self.rhos = []

    # ------------------------------------------------------------------
    # Initialisation
    # ------------------------------------------------------------------

    def start(self, branch: Branch, t: float) -> None:
        """Initialise the forward pass at the left boundary.

        Mirrors TSP_smc::start.

        Builds the interval grid over *branch* (clamped below at the cut time
        *t*) and seeds step 0 with each interval's Exp(1) mass,
        ``exp(-lb) - exp(-ub)``.
        """
        self.cut_time = t
        self.curr_index = 0
        self.curr_branch = branch
        self._temp = []
        self.curr_intervals = []
        self.source_interval = {}
        self.state_spaces = SortedDict()
        self.forward_probs = []
        self.rhos = []
        self.prev_rho = -1.0
        self.prev_theta = -1.0
        self.prev_node = None
        self._trace_back_key = None
        lb_start = branch.lower_node.time
        ub_start = branch.upper_node.time
        self.lower_bound = max(t, lb_start)
        self._generate_intervals(branch, lb_start, ub_start)
        # _generate_intervals already appended one 0.0 placeholder per
        # interval.  Replace them rather than appending, otherwise step 0
        # would hold dim zeros followed by the real priors (2 * dim entries).
        self._temp = [math.exp(-iv.lb) - math.exp(-iv.ub) for iv in self.curr_intervals]
        self.state_spaces[0] = list(self.curr_intervals)
        self.forward_probs.append(self._temp)
        self._temp = []
        self._set_dimensions()
        self._compute_factors()

    # ------------------------------------------------------------------
    # Forward step
    # ------------------------------------------------------------------

    def forward(self, rho: float) -> None:
        """Advance one bin with recombination rate *rho*.

        Mirrors TSP_smc::forward.  The new step's probability for interval i
        is lower_sums[i] + diagonals[i] * fp_prev[i]
        + lower_diagonals[i] * upper_sums[i], floored at _EPSILON for
        non-point intervals.
        """
        self.rhos.append(rho)
        if self.dim == 0:
            self.curr_index += 1
            self.prev_rho = rho
            self.forward_probs.append([])
            return
        self._compute_diagonals(rho)
        self._compute_lower_diagonals(rho)
        self._compute_upper_diagonals(rho)
        self._compute_lower_sums()
        self._compute_upper_sums()
        self.curr_index += 1
        self.prev_rho = rho
        prev_fp = self.forward_probs[self.curr_index - 1]
        new_fp = list(self.lower_sums)  # copy; accumulated into below
        for i in range(self.dim):
            new_fp[i] += (
                self.diagonals[i] * prev_fp[i]
                + self.lower_diagonals[i] * self.upper_sums[i]
            )
            if self.curr_intervals[i].lb != self.curr_intervals[i].ub:
                new_fp[i] = max(_EPSILON, new_fp[i])
        self.forward_probs.append(new_fp)

    # ------------------------------------------------------------------
    # Transfer at topology change
    # ------------------------------------------------------------------

    def transfer(
        self,
        r: "Recombination",
        prev_branch: Branch,
        next_branch: Branch,
    ) -> None:
        """Apply topology change *r* when moving from *prev_branch* to *next_branch*.

        Mirrors TSP_smc::transfer.  Three cases:

        * source -> merging: previous probs are constrained above
          ``r.start_time`` and the new step is a point mass at
          ``r.deleted_node.time``;
        * target -> recombined: previous probs collapse onto the inserted
          node and the new step puts weight 1 on intervals at or above
          ``r.start_time``;
        * otherwise: mass is carried over on the overlap of the two branches.
        """
        self.rhos.append(0.0)
        self.prev_rho = -1.0
        self.prev_theta = -1.0
        self.prev_node = None
        self._sanity_check(r)
        self.curr_index += 1
        self.curr_branch = next_branch
        self.lower_bound = max(self.cut_time, next_branch.lower_node.time)

        to_merging = prev_branch == r.source_branch and next_branch == r.merging_branch
        to_recombined = prev_branch == r.target_branch and next_branch == r.recombined_branch

        # Constrain previous step's probs based on topology
        if to_merging:
            self._set_interval_constraint(r)
        elif to_recombined:
            self._set_point_constraint(r)

        self.curr_intervals = []
        self._temp = []
        if to_merging:
            # Switch to a point mass at deleted_node.time, clamped to the
            # branch range so rescaling can't produce an out-of-bounds t.
            lb_b = next_branch.lower_node.time
            ub_b = next_branch.upper_node.time
            t = max(lb_b, min(ub_b, r.deleted_node.time))
            self._generate_intervals(next_branch, lb_b, t)
            n_before = len(self.curr_intervals)
            self._generate_intervals(next_branch, t, t)
            # Mark the point interval as the point mass (only if one was added;
            # point intervals on a branch boundary are skipped).
            if len(self.curr_intervals) > n_before:
                self._temp[-1] = 1.0
                self.curr_intervals[-1].node = r.deleted_node
            self._generate_intervals(next_branch, t, ub_b)
        elif to_recombined:
            # Switch from a point mass
            self._generate_intervals(
                next_branch, next_branch.lower_node.time, r.start_time
            )
            self._generate_intervals(
                next_branch, r.start_time, next_branch.upper_node.time
            )
            for i, iv in enumerate(self.curr_intervals):
                if iv.time >= r.start_time:
                    self._temp[i] = 1.0
        else:
            # Regular transfer: fresh grid below prev_branch, carried-over
            # mass on the overlap, then fresh grid up to the new upper node.
            lb = next_branch.lower_node.time
            ub = max(prev_branch.lower_node.time, next_branch.lower_node.time)
            self._generate_intervals(next_branch, lb, ub)
            self._transfer_intervals(r, prev_branch, next_branch)
            if self.curr_intervals:
                lb2 = min(self.curr_intervals[-1].ub, next_branch.upper_node.time)
            else:
                lb2 = next_branch.lower_node.time
            self._generate_intervals(next_branch, lb2, next_branch.upper_node.time)

        self.state_spaces[self.curr_index] = list(self.curr_intervals)
        self.forward_probs.append(list(self._temp))
        self._temp = []
        self._set_dimensions()
        self._compute_factors()

    # ------------------------------------------------------------------
    # Recombination (full re-sample)
    # ------------------------------------------------------------------

    def recombine(self, prev_branch: Branch, next_branch: Branch) -> None:
        """Re-sample the interval distribution at a recombination.

        Mirrors TSP_smc::recombine.  Builds a fresh grid over *next_branch*
        and spreads each previous interval's mass over it in proportion to
        ``_recomb_prob``, normalised by the probability of landing anywhere
        on the new grid.  *prev_branch* is accepted for API parity and is not
        read.  If the new grid is empty (branch entirely at or below the cut
        time) the new step is left empty.
        """
        prev_intervals = list(self.curr_intervals)
        prev_fp = list(self.forward_probs[self.curr_index])
        self.curr_intervals = []
        self._temp = []
        self.rhos.append(0.0)
        self.prev_rho = -1.0
        self.prev_theta = -1.0
        self.prev_node = None
        self.curr_branch = next_branch
        self.curr_index += 1
        self.lower_bound = max(self.cut_time, next_branch.lower_node.time)
        self._generate_intervals(
            next_branch,
            next_branch.lower_node.time,
            next_branch.upper_node.time,
        )
        # Initialise the new step to zero, then accumulate into it in place.
        new_fp = [0.0] * len(self.curr_intervals)
        self.forward_probs.append(new_fp)
        self.state_spaces[self.curr_index] = list(self.curr_intervals)
        self._set_dimensions()
        self._compute_factors()
        self._temp = []
        if not self.curr_intervals:
            return
        lb_full = self.curr_intervals[0].lb
        ub_full = self.curr_intervals[-1].ub
        for prev_iv, prev_p in zip(prev_intervals, prev_fp):
            base = self._recomb_prob(prev_iv.time, lb_full, ub_full)
            for j, curr_iv in enumerate(self.curr_intervals):
                if base == 0:
                    new_prob = 1.0
                else:
                    new_prob = (
                        self._recomb_prob(prev_iv.time, curr_iv.lb, curr_iv.ub)
                        * prev_p / base
                    )
                new_fp[j] += new_prob + _EPSILON

    # ------------------------------------------------------------------
    # Emission
    # ------------------------------------------------------------------

    def null_emit(self, theta: float, query_node: "Node") -> None:
        """Apply the no-mutation emission to the current step.

        Mirrors TSP_smc::null_emit.
        """
        if self.dim == 0:
            return
        self._compute_null_emit_probs(theta, query_node)
        self.prev_theta = theta
        self.prev_node = query_node
        self._apply_emission(self.null_emit_probs)

    def mut_emit(
        self,
        theta: float,
        bin_size: float,
        mut_set: Set[float],
        query_node: "Node",
    ) -> None:
        """Apply the mutation emission for *mut_set* to the current step.

        Mirrors TSP_smc::mut_emit.
        """
        if self.dim == 0:
            return
        self._compute_mut_emit_probs(theta, bin_size, mut_set, query_node)
        self._apply_emission(self.mut_emit_probs)

    def _apply_emission(self, weights: List[float]) -> None:
        """Multiply the current step's probs by *weights* and renormalise.

        If every weighted entry is zero the step is reset to uniform 1/dim.
        Callers guarantee ``self.dim > 0``.
        """
        fp = self.forward_probs[self.curr_index]
        ws = 0.0
        for i in range(self.dim):
            fp[i] *= weights[i]
            ws += fp[i]
        if ws > 0:
            for i in range(self.dim):
                fp[i] /= ws
        else:
            for i in range(self.dim):
                fp[i] = 1.0 / self.dim

    # ------------------------------------------------------------------
    # Traceback
    # ------------------------------------------------------------------

    def sample_joining_nodes(
        self,
        start_index: int,
        coordinates: List[float],
    ) -> Dict[float, Optional[Node]]:
        """Traceback to sample a joining-node map.

        Returns Dict[pos -> Node]; the rightmost position maps to None as a
        sentinel.  Mirrors TSP_smc::sample_joining_nodes.
        """
        self.prev_rho = -1.0
        self._trace_back_key = None
        joining_nodes: Dict[float, Optional[Node]] = {}
        x = self.curr_index
        pos = coordinates[x + start_index + 1]
        interval = self._sample_curr_interval(x)
        n = self._sample_joining_node(interval)
        joining_nodes[pos] = None  # sentinel for rightmost pos
        while x >= 0:
            x = self._trace_back_helper(interval, x)
            pos = coordinates[x + start_index]
            joining_nodes[pos] = n
            if x == 0:
                break
            if x == interval.start_pos:
                if id(interval) in self.source_interval:
                    x -= 1
                    interval = self._sample_source_interval(interval, x)
                else:
                    x -= 1
                    interval = self._sample_recomb_interval(interval, x)
                n = self._sample_joining_node(interval)
            else:
                x -= 1
                interval = self._sample_prev_interval(interval, x)
                n = self._sample_joining_node(interval)
        self.prev_rho = -1.0
        self._trace_back_key = None
        return joining_nodes

    # ------------------------------------------------------------------
    # Private: grid and interval generation
    # ------------------------------------------------------------------

    def _get_exp_quantile(self, p: float) -> float:
        """Inverse CDF of Exp(1): -log(1-p). Mirrors TSP_smc::get_exp_quantile."""
        if p < 1e-6:
            return 0.0
        if 1.0 - p < 1e-6:
            return math.inf
        return -math.log(1.0 - p)

    def _generate_grid(self, lb: float, ub: float) -> List[float]:
        """Generate a grid in [lb, ub] equally spaced in Exp(1) quantiles.

        The number of cells is ceil(quantile span / gap); the returned list
        always starts with *lb* and ends with *ub*.
        Mirrors TSP_smc::generate_grid.
        """
        lq = 1.0 - math.exp(-lb)
        uq = 1.0 - math.exp(-ub)
        q = uq - lq
        n = math.ceil(q / self.gap)
        points = [lb]
        for i in range(1, n):
            points.append(self._get_exp_quantile(lq + i * q / n))
        points.append(ub)
        return points

    def _generate_intervals(self, branch: Branch, lb: float, ub: float) -> None:
        """Append fine-grid intervals for [lb, ub] on *branch*.

        Both bounds are clamped to at least cut_time.  A zero-width range
        yields one point interval unless it sits on a branch boundary.  Each
        appended interval gets a 0.0 placeholder in ``_temp``.
        Mirrors TSP_smc::generate_intervals.
        """
        lb = max(self.cut_time, lb)
        ub = max(self.cut_time, ub)
        if lb > ub:
            return  # inverted bounds: mirrors C++ silently producing empty grid
        if lb == ub:
            # Point interval: skip boundary points
            lower_bound_iv = max(self.cut_time, branch.lower_node.time)
            upper_bound_iv = branch.upper_node.time
            if lb == lower_bound_iv or lb == upper_bound_iv:
                return
            iv = Interval(branch, lb, ub, self.curr_index)
            iv.fill_time()
            self.curr_intervals.append(iv)
            self._temp.append(0.0)
            return
        points = self._generate_grid(lb, ub)
        for l, u in zip(points, points[1:]):
            iv = Interval(branch, l, u, self.curr_index)
            iv.fill_time()
            self.curr_intervals.append(iv)
            self._temp.append(0.0)

    def _transfer_intervals(
        self,
        r: "Recombination",
        prev_branch: Branch,
        next_branch: Branch,
    ) -> None:
        """Transfer probability mass from previous intervals to overlapping new ones.

        Each new interval records its origin in ``source_interval`` so the
        traceback can follow it back.  Mirrors TSP_smc::transfer_intervals.
        """
        prev_intervals = self._get_state_space(self.curr_index - 1)
        prev_fp = self.forward_probs[self.curr_index - 1]
        # Same equality semantics as transfer() (C++ uses Branch::operator==).
        from_source = prev_branch == r.source_branch
        for i, interval in enumerate(prev_intervals):
            lb = max(interval.lb, next_branch.lower_node.time)
            ub = min(interval.ub, next_branch.upper_node.time)
            if from_source:
                ub = min(ub, r.start_time)
                if lb == r.start_time:
                    continue
            if lb == ub == next_branch.upper_node.time:
                continue
            if lb == ub == next_branch.lower_node.time:
                continue
            if ub >= lb:
                w = self._get_prop(lb, ub, interval.lb, interval.ub)
                p = w * prev_fp[i]
                new_iv = Interval(next_branch, lb, ub, self.curr_index)
                new_iv.fill_time()
                new_iv.node = interval.node
                self.source_interval[id(new_iv)] = interval
                self.curr_intervals.append(new_iv)
                self._temp.append(p)

    # ------------------------------------------------------------------
    # Private: dimension setup and transition factors
    # ------------------------------------------------------------------

    def _set_dimensions(self) -> None:
        """Resize all per-step work arrays to the current interval count."""
        self.dim = len(self.curr_intervals)
        self.diagonals = [0.0] * self.dim
        self.lower_diagonals = [0.0] * self.dim
        self.upper_diagonals = [0.0] * self.dim
        self.lower_sums = [0.0] * self.dim
        self.upper_sums = [0.0] * self.dim
        self.null_emit_probs = [0.0] * self.dim
        self.mut_emit_probs = [0.0] * self.dim
        self.factors = [0.0] * self.dim

    def _compute_factors(self) -> None:
        """Pre-compute ratio factors for the lower_sums recursion.

        factors[i] is the ratio of Exp(1) mass of interval i to interval i-1,
        capped at 5.0 (0.0 after a point interval, 5.0 after a very narrow
        one).  Mirrors TSP_smc::compute_factors.
        """
        if self.dim == 0:
            return
        self.factors[0] = 0.0
        for i in range(1, self.dim):
            iv_prev = self.curr_intervals[i - 1]
            iv_curr = self.curr_intervals[i]
            if iv_prev.lb == iv_prev.ub:
                self.factors[i] = 0.0
            elif iv_prev.ub - iv_prev.lb < 1e-4:
                self.factors[i] = 5.0
            else:
                num = math.exp(-iv_curr.lb) - math.exp(-iv_curr.ub)
                den = math.exp(-iv_prev.lb) - math.exp(-iv_prev.ub)
                self.factors[i] = min(num / den, 5.0) if den != 0 else 5.0

    # ------------------------------------------------------------------
    # Private: PSMC transition kernel
    # ------------------------------------------------------------------

    def _recomb_cdf(self, s: float, t: float) -> float:
        """Recombination CDF at *t* for a lineage at time *s*.

        Returns 1 for t = inf and 0 for t = 0.  Mirrors TSP_smc::recomb_cdf.
        """
        if math.isinf(t):
            return 1.0
        if t == 0:
            return 0.0
        l = s - self.cut_time
        if s > t:
            cdf = t + math.expm1(self.cut_time - t) - self.cut_time
        else:
            cdf = s + math.expm1(self.cut_time - t) - math.expm1(s - t) - self.cut_time
        return cdf / l

    def _recomb_prob(self, s: float, t1: float, t2: float) -> float:
        """P(recombination lands in [t1, t2] | current time s).

        Falls back to plain Exp(1) mass when s is within 0.005 of the lower
        bound; otherwise the result is floored at 1e-5.
        Mirrors TSP_smc::recomb_prob.
        """
        if s - max(self.lower_bound, self.cut_time) < 0.005:
            return math.exp(-t1) - math.exp(-t2)
        pl = self._recomb_cdf(s, t1)
        pu = self._recomb_cdf(s, t2)
        return max(pu - pl, 1e-5)

    def _psmc_cdf(self, rho: float, s: float, t: float) -> float:
        """PSMC transition CDF at *t* from source time *s*. Mirrors TSP_smc::psmc_cdf."""
        l = 2.0 * s - self.lower_bound - self.cut_time
        if l == 0:
            pre_factor = rho
        else:
            pre_factor = (1.0 - math.exp(-rho * l)) / l
        if t == self.cut_time and t == self.lower_bound:
            return 0.0
        elif t <= s:
            integral = (
                2.0 * t
                + math.exp(-t) * (math.exp(self.cut_time) + math.exp(self.lower_bound))
                - self.cut_time
                - self.lower_bound
                - 2.0
            )
        else:
            integral = (
                2.0 * s
                + math.exp(self.cut_time - t)
                + math.exp(self.lower_bound - t)
                - 2.0 * math.exp(s - t)
                - self.cut_time
                - self.lower_bound
            )
        return pre_factor * integral

    def _psmc_prob(self, rho: float, s: float, t1: float, t2: float) -> float:
        """PSMC probability of moving into [t1, t2] from source time *s*.

        Adds the no-recombination mass exp(-rho * l) when *s* lies in the
        interval, then clamps to [0, 1].  Mirrors TSP_smc::psmc_prob.
        """
        l = 2.0 * s - self.lower_bound - self.cut_time
        if t1 == s == t2 or t1 < s < t2:
            base = math.exp(-rho * l)
        else:
            base = 0.0
        gap = 0.0
        if t2 > t1:
            uq = self._psmc_cdf(rho, s, t2)
            lq = self._psmc_cdf(rho, s, t1)
            gap = max(uq - lq, 0.0)
        return max(0.0, min(1.0, base + gap))

    def _get_prop(self, lb1: float, ub1: float, lb2: float, ub2: float) -> float:
        """Proportion of [lb2, ub2] occupied by [lb1, ub1] in Exp(1) measure."""
        if ub2 - lb2 < 1e-6:
            return 1.0
        p1 = math.exp(-lb1) - math.exp(-ub1)
        p2 = math.exp(-lb2) - math.exp(-ub2)
        return p1 / p2 if p2 > 0 else 1.0

    # ------------------------------------------------------------------
    # Private: transition matrix computation
    # ------------------------------------------------------------------

    def _compute_diagonals(self, rho: float) -> None:
        """Compute stay-in-place probabilities (cached on rho). Mirrors compute_diagonals."""
        if rho == self.prev_rho:
            return
        lb = self.curr_intervals[0].lb
        ub = self.curr_intervals[-1].ub
        for i, iv in enumerate(self.curr_intervals):
            base = self._psmc_prob(rho, iv.time, lb, ub)
            diag = self._psmc_prob(rho, iv.time, iv.lb, iv.ub)
            self.diagonals[i] = diag / base if base > 0 else 0.0

    def _compute_lower_diagonals(self, rho: float) -> None:
        """P(interval i+1 -> interval i), normalised. Mirrors compute_lower_diagonals."""
        if rho == self.prev_rho:
            return
        lb = max(self.cut_time, self.curr_intervals[0].lb)
        ub = self.curr_intervals[-1].ub
        self.lower_diagonals[self.dim - 1] = 0.0
        for i in range(self.dim - 1):
            t = self.curr_intervals[i + 1].time
            base = self._psmc_prob(rho, t, lb, ub)
            ld = self._psmc_prob(
                rho, t, self.curr_intervals[i].lb, self.curr_intervals[i].ub
            )
            self.lower_diagonals[i] = ld / base if base > 0 else 0.0

    def _compute_upper_diagonals(self, rho: float) -> None:
        """P(interval i-1 -> interval i), normalised. Mirrors compute_upper_diagonals."""
        if rho == self.prev_rho:
            return
        lb = max(self.cut_time, self.curr_intervals[0].lb)
        ub = self.curr_intervals[-1].ub
        self.upper_diagonals[0] = 0.0
        for i in range(1, self.dim):
            t = self.curr_intervals[i - 1].time
            base = self._psmc_prob(rho, t, lb, ub)
            ud = self._psmc_prob(
                rho, t, self.curr_intervals[i].lb, self.curr_intervals[i].ub
            )
            self.upper_diagonals[i] = ud / base if base > 0 else 0.0

    def _compute_lower_sums(self) -> None:
        """Recursive mass arriving from lower intervals. Mirrors compute_lower_sums."""
        self.lower_sums[0] = 0.0
        fp = self.forward_probs[self.curr_index]
        for i in range(1, self.dim):
            self.lower_sums[i] = (
                self.upper_diagonals[i] * fp[i - 1]
                + self.factors[i] * self.lower_sums[i - 1]
            )

    def _compute_upper_sums(self) -> None:
        """upper_sums[i] = sum(fp[i+1:]). Mirrors compute_upper_sums."""
        fp = self.forward_probs[self.curr_index]
        self.upper_sums[self.dim - 1] = 0.0
        for i in range(self.dim - 2, -1, -1):
            self.upper_sums[i] = fp[i + 1] + self.upper_sums[i + 1]

    # ------------------------------------------------------------------
    # Private: emission computation
    # ------------------------------------------------------------------

    def _compute_emissions(
        self, mut_set: Set[float], branch: Branch, node: "Node"
    ) -> None:
        """Count allele differences against the majority state per mutation.

        Mirrors TSP_smc::compute_emissions.
        """
        self.emissions = [0.0, 0.0, 0.0, 0.0]
        for x in mut_set:
            sl = branch.lower_node.get_state(x)
            su = branch.upper_node.get_state(x)
            s0 = node.get_state(x)
            sm = 1.0 if (sl + su + s0 > 1.5) else 0.0
            self.emissions[0] += abs(sm - sl)
            self.emissions[1] += abs(sm - su)
            self.emissions[2] += abs(sm - s0)
            self.emissions[3] += abs(sl - su)

    def _compute_null_emit_probs(self, theta: float, query_node: "Node") -> None:
        """Fill null_emit_probs (cached on theta and query-node identity)."""
        if theta == self.prev_theta and query_node is self.prev_node:
            return
        for i, iv in enumerate(self.curr_intervals):
            self.null_emit_probs[i] = self.eh.null_emit(
                self.curr_branch, iv.time, theta, query_node
            )

    def _compute_mut_emit_probs(
        self,
        theta: float,
        bin_size: float,
        mut_set: Set[float],
        query_node: "Node",
    ) -> None:
        """Fill mut_emit_probs for the mutations in *mut_set*."""
        self._compute_emissions(mut_set, self.curr_branch, query_node)
        for i, iv in enumerate(self.curr_intervals):
            self.mut_emit_probs[i] = self.eh.emit(
                self.curr_branch, iv.time, theta, bin_size, self.emissions, query_node
            )

    def _compute_trace_back_probs(
        self,
        rho: float,
        interval: Interval,
        intervals: List[Interval],
    ) -> None:
        """Compute P(each of *intervals* -> *interval*) into trace_back_probs.

        The result depends on rho, the target interval, the candidate list,
        and lower_bound / cut_time.  It is therefore cached on all of these,
        not on rho alone: a rho-only cache returned stale (or freshly zeroed)
        probabilities once the traceback moved to a new interval with an
        unchanged rho.  Mirrors TSP_smc::compute_trace_back_probs.
        """
        cached = self._trace_back_key
        if (
            cached is not None
            and cached[0] == rho
            and cached[1] is interval
            and cached[2] is intervals
            and cached[3] == self.lower_bound
            and cached[4] == self.cut_time
        ):
            return
        probs: List[float] = []
        for iv in intervals:
            p = self._psmc_prob(rho, iv.time, interval.lb, interval.ub)
            if iv.lb < iv.ub:
                p = max(_EPSILON, p)
            probs.append(p)
        self.trace_back_probs = probs
        self._trace_back_key = (rho, interval, intervals, self.lower_bound, self.cut_time)

    # ------------------------------------------------------------------
    # Private: constraint setting at recombinations
    # ------------------------------------------------------------------

    def _sanity_check(self, r: "Recombination") -> None:
        """Zero point-mass intervals at inserted_node time not on the target branch.

        Mirrors TSP_smc::sanity_check.
        """
        fp = self.forward_probs[self.curr_index]
        for i, iv in enumerate(self.curr_intervals):
            if (iv.lb == iv.ub
                    and iv.lb == r.inserted_node.time
                    and iv.branch != r.target_branch):
                fp[i] = 0.0

    def _set_interval_constraint(self, r: "Recombination") -> None:
        """Zero/clamp previous probs at the source -> merging transition.

        Intervals entirely at or below r.start_time are zeroed; others have
        their lower bound raised to r.start_time.
        Mirrors TSP_smc::set_interval_constraint.
        """
        intervals = self._get_state_space(self.curr_index - 1)
        prev_fp = self.forward_probs[self.curr_index - 1]
        for i, iv in enumerate(intervals):
            if iv.ub <= r.start_time:
                prev_fp[i] = 0.0
            else:
                iv.lb = max(r.start_time, iv.lb)
                iv.fill_time()

    def _set_point_constraint(self, r: "Recombination") -> None:
        """Collapse previous probs onto the point-mass interval at inserted_node.

        Mirrors TSP_smc::set_point_constraint.
        """
        point_iv = self._search_point_interval(r)
        intervals = self._get_state_space(self.curr_index - 1)
        prev_fp = self.forward_probs[self.curr_index - 1]
        for i, iv in enumerate(intervals):
            if iv is point_iv:
                prev_fp[i] = 1.0
                iv.node = r.inserted_node
            else:
                prev_fp[i] = 0.0

    def _search_point_interval(self, r: "Recombination") -> Optional[Interval]:
        """Find the interval that should receive the point mass.

        Mirrors TSP_smc::search_point_interval.

        Raises:
            RuntimeError: if no interval strictly or exactly contains the
                inserted node's time and fewer than two intervals touch it.
        """
        t = r.inserted_node.time
        point_iv: Optional[Interval] = None
        # Priority 1: proper containing interval
        for iv in self.curr_intervals:
            if iv.lb < t < iv.ub:
                point_iv = iv
        # Priority 2: exact point interval (overrides priority 1)
        for iv in self.curr_intervals:
            if iv.lb == iv.ub == t:
                point_iv = iv
        if point_iv is not None:
            return point_iv
        # Fallback: two candidate intervals straddle t; pick one not coming
        # from inserted_node's branch via the source_interval chain.
        candidates = [iv for iv in self.curr_intervals if iv.lb <= t <= iv.ub]
        if len(candidates) < 2:
            raise RuntimeError(f"TSP transfer_sample: expected 2 candidate intervals, got {len(candidates)}")
        test_iv = candidates[0]
        while id(test_iv) in self.source_interval:
            test_iv = self.source_interval[id(test_iv)]
        if (test_iv.branch.upper_node is r.inserted_node
                or test_iv.branch.lower_node is r.inserted_node):
            return candidates[1]
        return candidates[0]

    # ------------------------------------------------------------------
    # Private: state-space lookup
    # ------------------------------------------------------------------

    def _floor_breakpoint(self, x: int) -> Optional[int]:
        """Return the largest state-space key <= x, or None if there is none."""
        idx = self.state_spaces.bisect_right(x)
        if idx == 0:
            return None
        return self.state_spaces.peekitem(idx - 1)[0]

    def _get_state_space(self, x: int) -> List[Interval]:
        """Return the state space valid at step x (floor lookup)."""
        key = self._floor_breakpoint(x)
        if key is None:
            return self.curr_intervals
        return self.state_spaces[key]

    def _get_prev_breakpoint(self, x: int) -> int:
        """Return the largest state-space key <= x (0 if none)."""
        key = self._floor_breakpoint(x)
        return 0 if key is None else key

    def _get_interval_index(
        self, interval: Interval, intervals: List[Interval]
    ) -> int:
        """Index of *interval* in *intervals* by identity, or 0 if absent."""
        for i, iv in enumerate(intervals):
            if iv is interval:
                return i
        return 0

    # ------------------------------------------------------------------
    # Private: traceback sampling
    # ------------------------------------------------------------------

    def _sample_curr_interval(self, x: int) -> Interval:
        """Sample an interval at step x proportionally to forward_probs[x]."""
        intervals = self._get_state_space(x)
        # If the state space at x is empty (degenerate branch), search backwards
        if not intervals:
            for k in reversed(self.state_spaces):
                if self.state_spaces[k]:
                    intervals = self.state_spaces[k]
                    x = k
                    break
        fp = self.forward_probs[x]
        w = sum(fp) * self._random()
        for i, iv in enumerate(intervals):
            w -= fp[i]
            if w <= 0:
                self.sample_index = i
                return iv
        self.sample_index = len(intervals) - 1
        return intervals[-1]

    def _sample_prev_interval(self, interval: Interval, x: int) -> Interval:
        """Sample the interval at step x that recombined into *interval*.

        *interval* itself is excluded; weights are trace_back_probs * fp[x].
        """
        intervals = self._get_state_space(x)
        if not intervals:
            return interval
        self.lower_bound = intervals[0].lb
        rho = self.rhos[x]
        self._compute_trace_back_probs(rho, interval, intervals)
        fp = self.forward_probs[x]
        ws = sum(
            self.trace_back_probs[i] * fp[i]
            for i in range(len(intervals))
            if intervals[i] is not interval
        )
        if ws <= 0:
            # Fallback: return the interval whose time is closest to the
            # current interval's representative time.  This preserves
            # temporal continuity and avoids systematic bias.
            target_t = interval.time
            best_i = min(range(len(intervals)),
                         key=lambda i: abs(intervals[i].time - target_t))
            self.sample_index = best_i
            return intervals[best_i]
        w = ws * self._random()
        for i, iv in enumerate(intervals):
            if iv is not interval:
                w -= self.trace_back_probs[i] * fp[i]
                if w <= 0:
                    self.sample_index = i
                    return iv
        self.sample_index = len(intervals) - 1
        return intervals[-1]

    def _sample_source_interval(self, interval: Interval, x: int) -> Interval:
        """Follow *interval* back to the interval it was transferred from."""
        src = self.source_interval[id(interval)]
        intervals = self._get_state_space(x)
        self.sample_index = self._get_interval_index(src, intervals)
        return src

    def _sample_recomb_interval(self, interval: Interval, x: int) -> Interval:
        """Sample the step-x interval that a recombination moved into *interval*."""
        if interval.lb == interval.ub:
            # Point mass: fall back to plain sampling
            return self._sample_curr_interval(x)
        intervals = self._get_state_space(x)
        fp = self.forward_probs[x]
        ws = sum(
            self._recomb_prob(iv.time, interval.lb, interval.ub) * fp[i]
            for i, iv in enumerate(intervals)
        )
        if ws <= 0:
            return self._sample_curr_interval(x)
        w = ws * self._random()
        for i, iv in enumerate(intervals):
            w -= self._recomb_prob(iv.time, interval.lb, interval.ub) * fp[i]
            if w <= 0:
                self.sample_index = i
                return iv
        self.sample_index = len(intervals) - 1
        return intervals[-1]

    def _trace_back_helper(self, interval: Interval, x: int) -> int:
        """Walk backward from step x while the lineage stays in *interval*.

        Draws q once and multiplies p by the non-recombination share at each
        step.  Returns the step where p first drops to <= q, or the previous
        state-space breakpoint if that is reached first.
        Mirrors TSP_smc::trace_back_helper.

        Raises:
            RuntimeError: if the total trace-back weight at a step is zero.
        """
        y = self._get_prev_breakpoint(x)
        q = self._random()
        p = 1.0
        intervals = self._get_state_space(x)
        if not intervals:
            return y  # degenerate: no state, skip to previous breakpoint
        self.lower_bound = intervals[0].lb
        while p > q and x > y:
            rho = self.rhos[x - 1]
            self._compute_trace_back_probs(rho, interval, intervals)
            self.prev_rho = rho
            prev_fp = self.forward_probs[x - 1]
            all_prob = sum(
                self.trace_back_probs[i] * prev_fp[i] for i in range(len(intervals))
            )
            if all_prob <= 0:
                raise RuntimeError("TSP trace_back_helper: zero all_prob")
            non_recomb = (
                self.trace_back_probs[self.sample_index] * prev_fp[self.sample_index]
            )
            p *= non_recomb / all_prob
            if p <= q:
                return x
            x -= 1
        return y

    def _sample_joining_node(self, interval: Interval) -> Node:
        """Return the interval's fixed node, or a new indexed node inside it.

        Mirrors TSP_smc::sample_joining_node (single-arg version).
        """
        global _counter
        if interval.node is not None:
            return interval.node
        t = self._exp_median(interval.lb, interval.ub)
        n = _new_node(t)
        n.index = _counter
        _counter += 1
        return n

    # ------------------------------------------------------------------
    # Private: time sampling helpers
    # ------------------------------------------------------------------

    def _exp_median(self, lb: float, ub: float) -> float:
        """Sample a time near the middle of [lb, ub] in Exp(1) quantile space.

        Draws from the 45-55% quantile band.  It uses linear space instead
        for narrow (<= 0.005) or far (lb > 10) intervals, and lb + U(0, 2)
        when ub is infinite.  Mirrors TSP_smc::exp_median.
        """
        if math.isinf(ub):
            return lb + 2.0 * self._random()
        if ub - lb <= 0.005:
            return (0.45 + 0.1 * self._random()) * (ub - lb) + lb
        if lb > 10:
            return (0.45 + 0.1 * self._random()) * (ub - lb) + lb
        lq = 1.0 - math.exp(-lb)
        uq = 1.0 - math.exp(-ub)
        mq = (0.45 + 0.1 * self._random()) * (uq - lq) + lq
        m = -math.log(1.0 - mq)
        return max(lb, min(ub, m))