# Source code for pysinger.hmm.bsp

"""
BSP — Branch Sequence Propagator, the forward HMM for branch threading.

Mirrors BSP_smc.cpp / BSP_smc.hpp.

The BSP computes the HMM forward probabilities over a set of Interval
objects (each representing a (branch, time) cell).  At each genomic
position the HMM:
  1. Optionally applies a Recombination transfer (moves probability mass
     from old intervals to new ones according to the topology change).
  2. Advances by one bin (forward step), mixing staying-in-place with
     recombining to a new time.
  3. Multiplies by the emission probability (null or with mutations).

After the forward pass, sample_joining_branches() performs a traceback
to return a map of pos → Branch for threading.
"""
from __future__ import annotations

import math
from typing import Dict, List, Optional, Set, Tuple, TYPE_CHECKING

import numpy as np
from sortedcontainers import SortedDict

from ..data.branch import Branch
from ..data.interval import Interval, IntervalInfo
from .coalescent import CoalescentCalculator

if TYPE_CHECKING:
    from ..data.node import Node
    from ..data.recombination import Recombination
    from .emission import Emission


class BSP:
    """Branch Sequence Propagator — forward HMM for the branch dimension.

    State space: a list of Interval objects.
    Forward probabilities: forward_probs[step][interval_idx].

    Mirrors BSP_smc in the C++ code.
    """

    def __init__(self) -> None:
        """Create an empty propagator; call :meth:`start` before stepping."""
        # Time at which the threaded lineage starts (set by start()).
        self.cut_time: float = 0.0
        # Partial intervals with transferred mass below this are pruned.
        self.cutoff: float = 0.0
        # Recombination positions for which the overwrite probability is 0.
        self.check_points: Set[float] = set()
        # Emission model; must be set before null_emit()/mut_emit().
        self.eh: Optional["Emission"] = None

        # Index of the current step and the intervals alive at it.
        self.curr_index: int = 0
        self.curr_intervals: List[Interval] = []
        self.valid_branches: Set[Branch] = set()

        # forward_probs[step] is a list of floats, one per interval
        self.forward_probs: List[List[float]] = []
        # state_spaces: step_idx → list of Intervals at that step
        self.state_spaces: SortedDict = SortedDict()  # int → List[Interval]

        # Per-interval scratch arrays, resized by _set_dimensions().
        self.recomb_probs: List[float] = []
        self.recomb_weights: List[float] = []
        self.null_emit_probs: List[float] = []
        self.mut_emit_probs: List[float] = []

        # Per-step history consumed by the traceback.
        self.rhos: List[float] = []
        self.recomb_sums: List[float] = []
        self.weight_sums: List[float] = []
        self.recomb_sum: float = 0.0
        self.weight_sum: float = 0.0

        self.dim: int = 0
        # True when valid_branches changed and self.cc must be recomputed.
        self.states_change: bool = False
        # Index (within its state space) of the most recently sampled interval.
        self.sample_index: int = 0
        self.cc: CoalescentCalculator = CoalescentCalculator(0.0)

        # Cache keys: the per-interval arrays are only recomputed when the
        # incoming rho / theta / query node differ from these.
        self.prev_rho: float = -1.0
        self.prev_theta: float = -1.0
        self.prev_node: Optional["Node"] = None

        self._rng: np.random.Generator = np.random.default_rng()

    def set_rng(self, rng: np.random.Generator) -> None:
        """Replace the random generator used by the traceback sampling."""
        self._rng = rng

    def _random(self) -> float:
        """Return one draw from ``self._rng.uniform()`` as a Python float."""
        return float(self._rng.uniform())

    # ------------------------------------------------------------------
    # Configuration
    # ------------------------------------------------------------------

    def reserve_memory(self, length: int) -> None:
        """Reset ``forward_probs`` to an empty list.

        *length* is accepted but not used by this implementation.
        """
        self.forward_probs = []

    def set_cutoff(self, x: float) -> None:
        """Set the pruning threshold applied to partial intervals."""
        self.cutoff = x

    def set_emission(self, e: "Emission") -> None:
        """Set the emission model used by null_emit() and mut_emit()."""
        self.eh = e

    def set_check_points(self, p: Set[float]) -> None:
        """Set the positions at which the overwrite probability is forced to 0."""
        self.check_points = p

    # ------------------------------------------------------------------
    # Initialisation
    # ------------------------------------------------------------------

    def start(self, branches: Set[Branch], t: float) -> None:
        """Initialise the forward pass at the left boundary.

        *branches* is the set of branches in the starting tree.
        *t* is the cut time (lineage starts at time t).

        Only branches whose upper node lies strictly above *t* become states.
        All per-run history is cleared, so the instance can be reused.

        Mirrors BSP_smc::start.
        """
        self.cut_time = t
        self.curr_index = 0
        self.valid_branches = set()
        self.curr_intervals = []
        self.rhos = []
        self.recomb_sums = []
        self.weight_sums = []
        self.state_spaces = SortedDict()
        # forward_probs is indexed by curr_index, which restarts at 0; stale
        # rows from a previous run would shift every later lookup.
        self.forward_probs = []
        # _set_dimensions() below zeroes the cached per-interval arrays, so
        # the cache keys must be invalidated or a matching rho/theta would
        # skip the recomputation and leave those zeros in place.
        self.prev_rho = -1.0
        self.prev_theta = -1.0
        self.prev_node = None

        for b in branches:
            if b.upper_node.time > t:
                self.valid_branches.add(b)

        self.cc = CoalescentCalculator(t)
        self.cc.compute(self.valid_branches)

        initial_probs: List[float] = []
        for b in sorted(self.valid_branches):
            lb = max(b.lower_node.time, t)
            ub = b.upper_node.time
            self.curr_intervals.append(Interval(b, lb, ub, self.curr_index))
            initial_probs.append(self.cc.weight(lb, ub))

        self.forward_probs.append(initial_probs)
        self._compute_interval_info()
        self.weight_sums.append(0.0)
        self._set_dimensions()
        self.state_spaces[self.curr_index] = list(self.curr_intervals)
        self.states_change = False

    # ------------------------------------------------------------------
    # Forward step
    # ------------------------------------------------------------------

    def forward(self, rho: float) -> None:
        """Advance forward by one bin with recombination rate *rho*.

        Appends one row to ``forward_probs`` and one entry each to ``rhos``,
        ``recomb_sums`` and ``weight_sums``.

        Mirrors BSP_smc::forward.
        """
        self.rhos.append(rho)
        self._compute_recomb_probs(rho)
        self._compute_recomb_weights(rho)
        self.prev_rho = rho
        self.curr_index += 1

        # recomb_sum = sum_i(recomb_probs[i] * forward_probs[curr-1][i])
        prev_fp = self.forward_probs[self.curr_index - 1]
        self.recomb_sum = sum(
            self.recomb_probs[i] * prev_fp[i] for i in range(self.dim)
        )

        # Stay in place with prob (1 - recomb), or receive a share of the
        # total recombining mass proportional to the recombination weight.
        new_fp = [
            prev_fp[i] * (1.0 - self.recomb_probs[i])
            + self.recomb_sum * self.recomb_weights[i]
            for i in range(self.dim)
        ]
        self.forward_probs.append(new_fp)
        self.recomb_sums.append(self.recomb_sum)
        self.weight_sums.append(self.weight_sum)

    # ------------------------------------------------------------------
    # Transfer at recombination
    # ------------------------------------------------------------------

    def transfer(self, r: "Recombination") -> None:
        """Apply topology change *r* to the interval state space.

        Mirrors BSP_smc::transfer.
        """
        self.rhos.append(0.0)
        # The state space is about to change: invalidate the caches.
        self.prev_rho = -1.0
        self.prev_theta = -1.0
        self.recomb_sums.append(0.0)
        self.weight_sums.append(0.0)
        # Must run before curr_index is advanced: it edits the current row.
        self.sanity_check(r)
        self.curr_index += 1

        transfer_weights: Dict[IntervalInfo, List[float]] = {}
        transfer_intervals: Dict[IntervalInfo, List[Interval]] = {}

        self._update_states(r.deleted_branches, r.inserted_branches)
        for i in range(len(self.curr_intervals)):
            self._process_interval(r, i, transfer_weights, transfer_intervals)
        self._add_new_branches(r, transfer_weights, transfer_intervals)
        self._generate_intervals(r, transfer_weights, transfer_intervals)
        self._set_dimensions()
        self.state_spaces[self.curr_index] = list(self.curr_intervals)

    # ------------------------------------------------------------------
    # Emission
    # ------------------------------------------------------------------

    def _apply_emission(self, emit_probs: List[float], caller: str) -> None:
        """Multiply the current forward row by *emit_probs* and renormalise.

        Raises RuntimeError (tagged with *caller*) when the row sums to <= 0.
        """
        fp = self.forward_probs[self.curr_index]
        ws = 0.0
        for i in range(self.dim):
            fp[i] *= emit_probs[i]
            ws += fp[i]
        if ws <= 0:
            raise RuntimeError(f"BSP {caller}: forward prob sum is zero")
        for i in range(self.dim):
            fp[i] /= ws

    def null_emit(self, theta: float, query_node: Optional["Node"]) -> None:
        """Apply null-emission (no mutations in bin).

        Raises RuntimeError if the forward probabilities sum to <= 0.

        Mirrors BSP_smc::null_emit.
        """
        self._compute_null_emit_prob(theta, query_node)
        self.prev_theta = theta
        self.prev_node = query_node
        self._apply_emission(self.null_emit_probs, "null_emit")

    def mut_emit(
        self,
        theta: float,
        bin_size: float,
        mut_set: Set[float],
        query_node: Optional["Node"],
    ) -> None:
        """Apply mutation emission.

        Raises RuntimeError if the forward probabilities sum to <= 0.

        Mirrors BSP_smc::mut_emit.
        """
        self._compute_mut_emit_probs(theta, bin_size, mut_set, query_node)
        self._apply_emission(self.mut_emit_probs, "mut_emit")

    def sanity_check(self, r: "Recombination") -> None:
        """Zero out invalid point-mass intervals at recombination nodes.

        A point-mass interval sitting exactly at the inserted node's time is
        only kept when it belongs to the target branch.

        Mirrors BSP_smc::sanity_check.
        """
        row = self.forward_probs[self.curr_index]
        for i, interval in enumerate(self.curr_intervals):
            if (
                interval.lb == interval.ub
                and interval.lb == r.inserted_node.time
                and interval.branch != r.target_branch
            ):
                row[i] = 0.0

    # ------------------------------------------------------------------
    # Traceback / sampling
    # ------------------------------------------------------------------

    def sample_joining_branches(
        self,
        start_index: int,
        coordinates: List[float],
    ) -> SortedDict:
        """Traceback to sample a joining-branch map.

        Mirrors BSP_smc::sample_joining_branches.
        Returns SortedDict[pos → Branch].
        """
        self.prev_rho = -1.0
        joining_branches: SortedDict = SortedDict()
        x = self.curr_index
        pos = coordinates[x + start_index + 1]
        interval = self._sample_curr_interval(x)
        joining_branches[pos] = interval.branch

        while x >= 0:
            x = self._trace_back_helper(interval, x)
            pos = coordinates[x + start_index]
            joining_branches[pos] = interval.branch
            if x == 0:
                break
            if x == interval.start_pos:
                # Reached the step where this interval was created by a
                # transfer: continue from one of its source intervals.
                x -= 1
                interval = self._sample_source_interval(interval, x)
            else:
                # A recombination was sampled: pick a fresh interval.
                x -= 1
                interval = self._sample_prev_interval(x)

        self._simplify(joining_branches)
        return joining_branches

    def avg_num_states(self) -> float:
        """Average number of states (intervals) per position."""
        if len(self.state_spaces) <= 1:
            return 0.0
        span = 0
        count = 0.0
        keys = list(self.state_spaces.keys())
        for prev_key, key in zip(keys, keys[1:]):
            if key == _INT_MAX_IDX:
                break
            count += len(self.state_spaces[key]) * (key - prev_key)
            span = key
        return count / span if span > 0 else 0.0

    # ------------------------------------------------------------------
    # Private helpers: state management
    # ------------------------------------------------------------------

    def _update_states(
        self,
        deletions: Set[Branch],
        insertions: Set[Branch],
    ) -> None:
        """Apply branch deletions/insertions above the cut time."""
        for b in deletions:
            if b.upper_node.time > self.cut_time:
                self.valid_branches.discard(b)
                self.states_change = True
        for b in insertions:
            if b.upper_node.time > self.cut_time:
                self.valid_branches.add(b)
                self.states_change = True

    def _set_dimensions(self) -> None:
        """Resize (and zero) the per-interval arrays to the current dimension."""
        self.dim = len(self.curr_intervals)
        self.recomb_probs = [0.0] * self.dim
        self.recomb_weights = [0.0] * self.dim
        self.null_emit_probs = [0.0] * self.dim
        self.mut_emit_probs = [0.0] * self.dim

    def _compute_interval_info(self) -> None:
        """Update weight and representative time for each current interval."""
        if self.states_change:
            self.cc.compute(self.valid_branches)
            self.states_change = False
        for interval in self.curr_intervals:
            p = self.cc.weight(interval.lb, interval.ub)
            t = self.cc.time(interval.lb, interval.ub)
            interval.assign_weight(p)
            interval.assign_time(t)

    def _get_recomb_prob(self, rho: float, t: float) -> float:
        """P(recombination | rho, t) = rho*(t-cut_time)*exp(-rho*(t-cut_time))."""
        dt = t - self.cut_time
        return rho * dt * math.exp(-rho * dt)

    def _compute_recomb_probs(self, rho: float) -> None:
        """Fill ``recomb_probs`` for *rho*; skipped when *rho* is unchanged."""
        if rho == self.prev_rho:
            return
        for i, interval in enumerate(self.curr_intervals):
            self.recomb_probs[i] = self._get_recomb_prob(rho, interval.time)

    def _compute_recomb_weights(self, rho: float) -> None:
        """Fill and normalise ``recomb_weights``; skipped when *rho* is unchanged.

        Only intervals reporting ``full(cut_time)`` receive weight. The
        un-normalised total is stored in ``weight_sum``.
        """
        if rho == self.prev_rho:
            return
        for i, interval in enumerate(self.curr_intervals):
            if interval.full(self.cut_time):
                self.recomb_weights[i] = self.recomb_probs[i] * interval.weight
            else:
                self.recomb_weights[i] = 0.0
        ws = sum(self.recomb_weights)
        self.weight_sum = ws
        if ws > 0:
            for i in range(self.dim):
                self.recomb_weights[i] /= ws

    def _compute_null_emit_prob(
        self, theta: float, query_node: Optional["Node"]
    ) -> None:
        """Fill ``null_emit_probs``; skipped when theta and node are unchanged."""
        if theta == self.prev_theta and query_node is self.prev_node:
            return
        for i, interval in enumerate(self.curr_intervals):
            self.null_emit_probs[i] = self.eh.null_emit(
                interval.branch, interval.time, theta, query_node
            )

    def _compute_mut_emit_probs(
        self,
        theta: float,
        bin_size: float,
        mut_set: Set[float],
        query_node: Optional["Node"],
    ) -> None:
        """Fill ``mut_emit_probs`` from the emission model (never cached)."""
        for i, interval in enumerate(self.curr_intervals):
            self.mut_emit_probs[i] = self.eh.mut_emit(
                interval.branch, interval.time, theta, bin_size, mut_set, query_node
            )

    # ------------------------------------------------------------------
    # Private helpers: transfer
    # ------------------------------------------------------------------

    def _add_new_branches(
        self,
        r: "Recombination",
        tw: Dict[IntervalInfo, List[float]],
        ti: Dict[IntervalInfo, List[Interval]],
    ) -> None:
        """Add recombined and merging branches to transfer maps."""
        # Merging branch first, then recombined branch (dict order matters
        # for the order of the generated intervals).
        for branch in (r.merging_branch, r.recombined_branch):
            if not (
                branch
                and branch.lower_node is not None
                and branch.upper_node is not None
                and branch.upper_node.time > self.cut_time
            ):
                continue
            lb = max(self.cut_time, branch.lower_node.time)
            ub = branch.upper_node.time
            if lb <= ub:
                # Register the key only; no source interval contributes mass.
                self._transfer_helper(IntervalInfo(branch, lb, ub), None, 0.0, tw, ti)

    def _transfer_helper(
        self,
        key: IntervalInfo,
        prev_interval: Optional[Interval],
        w: float,
        tw: Dict[IntervalInfo, List[float]],
        ti: Dict[IntervalInfo, List[Interval]],
    ) -> None:
        """Register *key* and, if a source interval is given, record its mass."""
        if key not in tw:
            tw[key] = []
            ti[key] = []
        if prev_interval is not None:
            tw[key].append(w)
            ti[key].append(prev_interval)

    def _process_interval(
        self,
        r: "Recombination",
        i: int,
        tw: Dict[IntervalInfo, List[float]],
        ti: Dict[IntervalInfo, List[Interval]],
    ) -> None:
        """Dispatch interval *i* by its branch's role in recombination *r*."""
        b = self.curr_intervals[i].branch
        if b == r.source_branch:
            self._process_source_interval(r, i, tw, ti)
        elif b == r.target_branch:
            self._process_target_interval(r, i, tw, ti)
        else:
            self._process_other_interval(r, i, tw, ti)

    def _process_source_interval(
        self,
        r: "Recombination",
        i: int,
        tw: Dict,
        ti: Dict,
    ) -> None:
        """Transfer mass of an interval lying on the source branch."""
        prev = self.curr_intervals[i]
        p = self.forward_probs[self.curr_index - 1][i]
        break_time = r.start_time
        point_time = r.source_branch.upper_node.time

        if prev.ub <= break_time:
            # Entirely below the break: follows the recombined branch.
            key = IntervalInfo(r.recombined_branch, prev.lb, prev.ub)
            self._transfer_helper(key, prev, p, tw, ti)
        elif prev.lb >= break_time:
            # Entirely above the break: collapses to a point mass.
            key = IntervalInfo(r.merging_branch, point_time, point_time)
            self._transfer_helper(key, prev, p, tw, ti)
        else:
            # Straddles the break: split in proportion to coalescent weight.
            w1 = self.cc.weight(prev.lb, break_time)
            w2 = self.cc.weight(break_time, prev.ub)
            if w1 == 0 and w2 == 0:
                w1, w2 = 1.0, 0.0
            else:
                total = w1 + w2
                w1 = w1 / total
                w2 = 1.0 - w1
            key1 = IntervalInfo(r.recombined_branch, prev.lb, break_time)
            self._transfer_helper(key1, prev, w1 * p, tw, ti)
            key2 = IntervalInfo(r.merging_branch, point_time, point_time)
            self._transfer_helper(key2, prev, w2 * p, tw, ti)

    def _process_target_interval(
        self,
        r: "Recombination",
        i: int,
        tw: Dict,
        ti: Dict,
    ) -> None:
        """Transfer mass of an interval lying on the target branch."""
        prev = self.curr_intervals[i]
        p = self.forward_probs[self.curr_index - 1][i]
        join_time = r.inserted_node.time

        if prev.lb == prev.ub == join_time:
            # Point mass exactly at the join: moves onto the recombined branch.
            lb = max(self.cut_time, r.start_time)
            ub = r.recombined_branch.upper_node.time
            if lb <= ub:
                key = IntervalInfo(r.recombined_branch, lb, ub)
                self._transfer_helper(key, prev, p, tw, ti)
        elif prev.lb >= join_time:
            key = IntervalInfo(r.upper_transfer_branch, prev.lb, prev.ub)
            self._transfer_helper(key, prev, p, tw, ti)
        elif prev.ub <= join_time:
            key = IntervalInfo(r.lower_transfer_branch, prev.lb, prev.ub)
            self._transfer_helper(key, prev, p, tw, ti)
        else:
            # Straddles the join: w0 goes to the recombined branch, the rest
            # is split below/above the join by coalescent weight.
            w0 = self._get_overwrite_prob(r, prev.lb, prev.ub)
            w1 = self.cc.weight(prev.lb, join_time)
            w2 = self.cc.weight(join_time, prev.ub)
            if w1 + w2 == 0:
                w1 = w2 = 0.0
                w0 = 1.0
            else:
                total = w1 + w2
                w1 = w1 / total
                w2 = 1.0 - w1
                w1 *= (1.0 - w0)
                w2 *= (1.0 - w0)
            key1 = IntervalInfo(r.lower_transfer_branch, prev.lb, join_time)
            self._transfer_helper(key1, prev, w1 * p, tw, ti)
            key2 = IntervalInfo(r.upper_transfer_branch, join_time, prev.ub)
            self._transfer_helper(key2, prev, w2 * p, tw, ti)
            lb = max(r.start_time, self.cut_time)
            ub = r.recombined_branch.upper_node.time
            if lb <= ub:
                key3 = IntervalInfo(r.recombined_branch, lb, ub)
                self._transfer_helper(key3, prev, w0 * p, tw, ti)

    def _process_other_interval(
        self,
        r: "Recombination",
        i: int,
        tw: Dict,
        ti: Dict,
    ) -> None:
        """Transfer mass of an interval on neither the source nor target branch."""
        prev = self.curr_intervals[i]
        p = self.forward_probs[self.curr_index - 1][i]
        # Branches affected by r are re-homed onto the merging branch.
        target = r.merging_branch if r.affect(prev.branch) else prev.branch
        key = IntervalInfo(target, prev.lb, prev.ub)
        self._transfer_helper(key, prev, p, tw, ti)

    def _get_overwrite_prob(self, r: "Recombination", lb: float, ub: float) -> float:
        """Share of a straddling target interval's mass sent to the recombined branch.

        Returns 0 when ``r.pos`` is a check point.
        """
        if r.pos in self.check_points:
            return 0.0
        join_time = r.inserted_node.time
        p1 = self.cc.weight(lb, ub)
        p2 = self.cc.weight(max(self.cut_time, r.start_time), join_time)
        if p1 == 0 and p2 == 0:
            return 1.0
        return p2 / (p1 + p2)

    def _generate_intervals(
        self,
        r: "Recombination",
        tw: Dict[IntervalInfo, List[float]],
        ti: Dict[IntervalInfo, List[Interval]],
    ) -> None:
        """Build new curr_intervals from transfer maps."""
        new_intervals: List[Interval] = []
        new_fp: List[float] = []

        for key, weights in tw.items():
            b = key.branch
            if b.lower_node is None or b.upper_node is None:
                continue
            lb = key.lb
            ub = key.ub
            p = sum(weights)
            is_full = (
                lb == max(self.cut_time, b.lower_node.time)
                and ub == b.upper_node.time
            )
            # Full-branch intervals are always kept; partial ones only when
            # they carry at least `cutoff` probability mass.
            if not (is_full or p >= self.cutoff):
                continue
            new_iv = Interval(b, lb, ub, self.curr_index)
            new_intervals.append(new_iv)
            new_fp.append(p)
            if weights:
                new_iv.source_weights = list(weights)
                new_iv.source_intervals = list(ti[key])

        self.forward_probs.append(new_fp)
        self.curr_intervals = new_intervals
        self._compute_interval_info()

    # ------------------------------------------------------------------
    # Private helpers: state-space lookup
    # ------------------------------------------------------------------

    def _get_state_space(self, x: int) -> List[Interval]:
        """Return the interval list in force at step *x*.

        This is the entry with the largest key <= x; falls back to
        ``curr_intervals`` when no such key exists.
        """
        idx = self.state_spaces.bisect_right(x) - 1
        if idx < 0:
            return self.curr_intervals
        key = self.state_spaces.keys()[idx]
        return self.state_spaces[key]

    def _get_prev_breakpoint(self, x: int) -> int:
        """Return the largest state-space key <= x, or 0 when there is none."""
        idx = self.state_spaces.bisect_right(x) - 1
        if idx < 0:
            return 0
        return self.state_spaces.keys()[idx]

    # ------------------------------------------------------------------
    # Private helpers: traceback sampling
    # ------------------------------------------------------------------

    def _sample_curr_interval(self, x: int) -> Interval:
        """Sample an interval at step *x* in proportion to its forward prob."""
        intervals = self._get_state_space(x)
        fp = self.forward_probs[x]
        ws = sum(fp)
        q = self._random()
        w = ws * q
        for i, interval in enumerate(intervals):
            w -= fp[i]
            if w <= 0:
                self.sample_index = i
                return interval
        # Fallback
        self.sample_index = len(intervals) - 1
        return intervals[-1]

    def _sample_prev_interval(self, x: int) -> Interval:
        """Sample the interval at step *x* from which a recombination departed."""
        intervals = self._get_state_space(x)
        rho = self.rhos[x]
        ws = self.recomb_sums[x]
        q = self._random()
        w = ws * q
        for i, interval in enumerate(intervals):
            contrib = self._get_recomb_prob(rho, interval.time) * self.forward_probs[x][i]
            w -= contrib
            if w <= 0:
                self.sample_index = i
                return interval
        self.sample_index = len(intervals) - 1
        return intervals[-1]

    def _sample_source_interval(self, interval: Interval, x: int) -> Interval:
        """Sample one of *interval*'s transfer sources by its recorded weight.

        ``sample_index`` is set to the source's position in the state space at
        step *x*, or to 0 when the source is not found there.
        """
        prev_intervals = self._get_state_space(x)
        weights = interval.source_weights
        sources = interval.source_intervals
        ws = sum(weights)
        q = self._random()
        w = ws * q
        for i, src in enumerate(sources):
            w -= weights[i]
            if w <= 0:
                chosen = src
                break
        else:
            # Rounding left a sliver of mass: fall back to the last source.
            chosen = sources[-1]
        try:
            self.sample_index = prev_intervals.index(chosen)
        except ValueError:
            # Best effort, as in the original: default to the first slot.
            self.sample_index = 0
        return chosen

    def _trace_back_helper(self, interval: Interval, x: int) -> int:
        """Walk backward from step x, deciding when to jump to a new lineage.

        Returns the step at which the walk stops. This is
        ``interval.start_pos`` when the interval is not ``full`` or no
        recombination is sampled.

        Mirrors BSP_smc::trace_back_helper.
        """
        if not interval.full(self.cut_time):
            return interval.start_pos
        p = self._random()
        q = 1.0
        while x > interval.start_pos:
            recomb_sum = self.recomb_sums[x - 1]
            weight_sum = self.weight_sums[x]
            if recomb_sum == 0:
                shrinkage = 1.0
            else:
                rp = self._get_recomb_prob(self.rhos[x - 1], interval.time)
                non_recomb_prob = (1.0 - rp) * self.forward_probs[x - 1][self.sample_index]
                all_prob = non_recomb_prob + recomb_sum * interval.weight * rp / (
                    weight_sum if weight_sum > 0 else 1.0
                )
                if all_prob <= 0:
                    shrinkage = 1.0
                else:
                    shrinkage = non_recomb_prob / all_prob
            shrinkage = max(0.0, min(1.0, shrinkage))
            # q is the running probability of "no recombination so far".
            q *= shrinkage
            if p >= q:
                return x
            x -= 1
        return interval.start_pos

    @staticmethod
    def _simplify(joining_branches: SortedDict) -> None:
        """Deduplicate consecutive identical branches in the joining-branch map.

        Works in place. The first and last positions are always retained.
        """
        if len(joining_branches) <= 1:
            return
        keys = list(joining_branches.keys())
        curr = joining_branches[keys[0]]
        kept = {keys[0]: curr}
        for k in keys[1:]:
            if joining_branches[k] != curr:
                curr = joining_branches[k]
                kept[k] = curr
        # Always keep the last entry
        kept[keys[-1]] = joining_branches[keys[-1]]
        joining_branches.clear()
        for k, v in kept.items():
            joining_branches[k] = v
# Sentinel state-space key used in BSP.avg_num_states: iteration over the
# state-space keys stops when this value is encountered.
_INT_MAX_IDX = 10**18