Source code for pysinger.hmm.coalescent

"""
CoalescentCalculator — piecewise-exponential coalescent CDF/quantile.

Mirrors Coalescent_calculator.cpp / Coalescent_calculator.hpp.

Given a set of branches spanning a range of times, computes the
probability that a new lineage coalesces in any time interval [lb, ub].
The coalescent rate at time t is the number of lineages currently alive
at t (i.e., spanning t).

The CDF is piecewise-exponential: within each interval the rate is
constant, so the probability of not coalescing across an interval of
length Δt is exp(−rate * Δt).
"""
from __future__ import annotations

import bisect
import math
from typing import List, Set, Tuple, TYPE_CHECKING

from sortedcontainers import SortedDict

if TYPE_CHECKING:
    from ..data.branch import Branch


class CoalescentCalculator:
    """Piecewise-exponential coalescent CDF for a set of branches.

    Usage::

        cc = CoalescentCalculator(cut_time=0.0)
        cc.compute(set_of_branches)
        p = cc.weight(lb, ub)   # probability of coalescence in [lb, ub]
        t = cc.time(lb, ub)     # representative time in [lb, ub]
    """

    def __init__(self, cut_time: float) -> None:
        """Create an empty calculator.

        Args:
            cut_time: Lower clamp applied to every branch start time in
                :meth:`compute`.
        """
        self.cut_time = cut_time
        # rate_changes[time] = Δrate (positive when branches start, negative when they end)
        self._rate_changes: SortedDict = SortedDict()
        # rates[time] = current coalescent rate (cumulative sum of rate_changes)
        self._rates: SortedDict = SortedDict()
        # Cumulative probabilities indexed by time (SortedDict[time → cum_prob])
        self._cum_probs: SortedDict = SortedDict()
        # Same data as parallel lists for quantile lookup (sorted by cum_prob)
        self._prob_vals: List[float] = []   # cum_prob values (sorted)
        self._prob_times: List[float] = []  # corresponding times
        self.min_time: float = 0.0
        self.max_time: float = math.inf

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def compute(self, branches: Set["Branch"]) -> None:
        """Recompute the CDF from *branches*.

        All derived tables are rebuilt from scratch, so nothing from a
        previous call survives (including when *branches* is empty).

        Mirrors Coalescent_calculator::compute.
        """
        self._compute_rate_changes(branches)
        self._compute_rates()
        self._compute_probs_quantiles()

    def weight(self, lb: float, ub: float) -> float:
        """Return the probability of coalescence in [lb, ub].

        Computed as ``prob(ub) - prob(lb)``.

        Mirrors Coalescent_calculator::weight.
        """
        return self.prob(ub) - self.prob(lb)

    def time(self, lb: float, ub: float) -> float:
        """Return a representative coalescence time in [lb, ub].

        Uses the exponential median; falls back to the midpoint when the
        interval (in time or in probability mass) is tiny.  For an
        infinite *ub* the result is ``lb + ln 2``.

        Mirrors Coalescent_calculator::time.
        """
        if math.isinf(ub):
            return lb + math.log(2)
        midpoint = 0.5 * (lb + ub)
        if ub - lb < 1e-3:
            return midpoint
        lq = self.prob(lb)
        uq = self.prob(ub)
        if uq - lq < 1e-3:
            return midpoint
        t = self.quantile(0.5 * (lq + uq))
        # Clamp so the answer never leaves the requested interval.
        return max(lb, min(ub, t))

    def prob(self, x: float) -> float:
        """Return cumulative coalescence probability at time *x*.

        Values at or below ``min_time`` give 0; values at or above
        ``max_time`` give the probability stored at ``max_time``.  Between
        stored break points the CDF is interpolated exponentially using the
        rate of the enclosing interval.

        Mirrors Coalescent_calculator::prob.
        """
        if not self._cum_probs:
            return 0.0
        if x >= self.max_time:
            x = self.max_time
        elif x <= self.min_time:
            return 0.0

        # Exact hit on a stored break point.
        if x in self._cum_probs:
            return self._cum_probs[x]

        # Locate the break points on either side of x.
        keys = self._cum_probs.keys()
        u_idx = self._cum_probs.bisect_right(x)
        l_idx = u_idx - 1
        if l_idx < 0:
            return 0.0
        if u_idx >= len(keys):
            return self._cum_probs[keys[-1]]

        l_key = keys[l_idx]
        u_key = keys[u_idx]
        base_prob = self._cum_probs[l_key]
        rate = self._rates.get(l_key, 0)
        if rate == 0:
            # No lineages in this interval: the CDF is flat.
            return base_prob

        delta_t = u_key - l_key
        delta_p = self._cum_probs[u_key] - base_prob
        new_delta_t = x - l_key

        # Interpolation formula from the C++ code:
        #   new_delta_p = delta_p * expm1(-rate * new_delta_t) / expm1(-rate * delta_t)
        denom = math.expm1(-rate * delta_t)
        if abs(denom) < 1e-15:
            # Degenerate interval: fall back to linear interpolation.
            new_delta_p = delta_p * new_delta_t / delta_t
        else:
            new_delta_p = delta_p * math.expm1(-rate * new_delta_t) / denom
        return base_prob + new_delta_p

    def quantile(self, p: float) -> float:
        """Return the time t such that prob(t) == p.

        Returns ``min_time`` when no CDF has been built.  A *p* outside the
        stored probability range resolves to the first or last stored time.

        Mirrors Coalescent_calculator::quantile.
        """
        if not self._prob_vals:
            return self.min_time

        # Find the interval [l, u] where the cum_prob crosses p, clamping
        # both ends into the table.
        idx = bisect.bisect_right(self._prob_vals, p)
        l_idx = max(idx - 1, 0)
        u_idx = min(idx, len(self._prob_vals) - 1)

        l_time = self._prob_times[l_idx]
        u_time = self._prob_times[u_idx]
        l_prob = self._prob_vals[l_idx]
        delta_p = self._prob_vals[u_idx] - l_prob
        if delta_p < 1e-15:
            # Flat (or clamped) segment: nothing to invert.
            return l_time

        delta_t = u_time - l_time
        new_delta_p = p - l_prob
        rate = self._rates.get(l_time, 0)
        if rate == 0:
            return l_time + delta_t * new_delta_p / delta_p

        # Inverse formula from C++:
        #   new_delta_t = -log(1 - new_delta_p/delta_p * (1 - exp(-rate*delta_t))) / rate
        # written with expm1/log1p so it stays accurate when rate*delta_t is
        # tiny, matching the expm1-based forward formula in prob().
        frac = (new_delta_p / delta_p) * (-math.expm1(-rate * delta_t))
        if frac >= 1.0:
            return u_time
        return l_time + (-math.log1p(-frac) / rate)

    # ------------------------------------------------------------------
    # Private: CDF construction
    # ------------------------------------------------------------------

    def _compute_rate_changes(self, branches: Set["Branch"]) -> None:
        """Record +1 at branch start and -1 at branch end.

        Mirrors C++ exactly: rate_changes[ub] -= 1 even for ub=inf, so
        max_time = inf for any branch reaching the root sentinel.

        With no branches, ``min_time`` / ``max_time`` are reset to their
        initial values (0.0 and inf) instead of keeping stale ones.
        """
        self._rate_changes = SortedDict()
        for b in branches:
            lb = max(self.cut_time, b.lower_node.time)
            ub = b.upper_node.time
            self._rate_changes[lb] = self._rate_changes.get(lb, 0) + 1
            self._rate_changes[ub] = self._rate_changes.get(ub, 0) - 1

        if not self._rate_changes:
            self.min_time = 0.0
            self.max_time = math.inf
            return

        change_times = self._rate_changes.keys()
        self.min_time = change_times[0]
        self.max_time = change_times[-1]

    def _compute_rates(self) -> None:
        """Cumulative sum of rate_changes → piecewise constant rate."""
        self._rates = SortedDict()
        curr = 0
        for t, delta in self._rate_changes.items():
            curr += delta
            self._rates[t] = curr

    def _compute_probs_quantiles(self) -> None:
        """Build the piecewise CDF from the rates.

        Populates ``_cum_probs`` (time → cumulative probability) and the
        parallel ``_prob_vals`` / ``_prob_times`` lists used by
        :meth:`quantile`.  All three are cleared first so an empty rate
        table leaves no stale entries behind.

        Mirrors Coalescent_calculator::compute_probs_quantiles.
        """
        self._cum_probs = SortedDict()
        self._prob_vals = []
        self._prob_times = []
        if not self._rates:
            return

        steps = list(self._rates.items())
        survival = 1.0  # probability of not having coalesced yet
        cum_prob = 0.0
        for (prev_time, curr_rate), (next_time, _) in zip(steps, steps[1:]):
            if curr_rate > 0:
                next_survival = survival * math.exp(-curr_rate * (next_time - prev_time))
                cum_prob += survival - next_survival
                survival = next_survival
            self._cum_probs[next_time] = cum_prob

        # Sentinel at min_time with cumulative probability 0
        self._cum_probs[self.min_time] = 0.0

        # Build parallel arrays sorted by cum_prob for quantile lookup.
        # (Mirrors C++ quantiles set sorted by {cum_prob, time}.)
        pairs = sorted((cp, t) for t, cp in self._cum_probs.items())
        self._prob_vals = [cp for cp, _ in pairs]
        self._prob_times = [t for _, t in pairs]