# Source code for pysinger.hmm.emission

"""
Emission models for the BSP/TSP forward HMM.

Mirrors Binary_emission.cpp, Polar_emission.cpp, and Emission.hpp.

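The abstract base class Emission defines two abstract methods:
  null_emit(branch, time, theta, node)   — no mutation in bin
  mut_emit(branch, time, theta, bin_size, mut_set, node) — mutations in bin
plus emit(branch, time, theta, bin_size, emissions, node), which takes
pre-computed difference counts (used by TSP).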

BinaryEmission — symmetric infinite-sites model (C++ Binary_emission).
PolarEmission  — polarised (ancestral/derived) model (C++ Polar_emission).
"""
from __future__ import annotations

import math
from abc import ABC, abstractmethod
from typing import Set, TYPE_CHECKING

if TYPE_CHECKING:
    from ..data.branch import Branch
    from ..data.node import Node


class Emission(ABC):
    """Abstract emission model for the HMM.

    Concrete models must implement :meth:`null_emit` and :meth:`mut_emit`.
    :meth:`emit` is not abstract; models that support pre-computed
    difference counts override it.
    """

    @abstractmethod
    def null_emit(
        self,
        branch: "Branch",
        time: float,
        theta: float,
        node: "Node",
    ) -> float:
        """Return the emission probability when no mutations fall in the bin."""

    @abstractmethod
    def mut_emit(
        self,
        branch: "Branch",
        time: float,
        theta: float,
        bin_size: float,
        mut_set: Set[float],
        node: "Node",
    ) -> float:
        """Return the emission probability when mut_set mutations fall in the bin."""

    def emit(
        self,
        branch: "Branch",
        time: float,
        theta: float,
        bin_size: float,
        emissions: list,
        node: "Node",
    ) -> float:
        """Emit using pre-computed diff counts (used by TSP).

        The base implementation has nothing to compute, so it fails loudly
        instead of handing ``None`` to a caller that expects a float.

        Raises:
            NotImplementedError: always, unless a subclass overrides this
                method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement emit()"
        )
# ---------------------------------------------------------------------------
# BinaryEmission
# ---------------------------------------------------------------------------
class BinaryEmission(Emission):
    """Symmetric infinite-sites emission model.

    Mirrors Binary_emission.cpp.

    For a branch (lower, upper) and a query node joining at time *t*, each of
    the three segments (lower -> t, t -> upper, query -> t) contributes
    ``exp(-theta * length) * (theta * length / bin_size) ** s``. The product
    is divided by the same expression for the original, unthreaded branch so
    that the HMM receives a ratio.
    """

    def _calculate_prob(self, theta: float, bin_size: float, s: int) -> float:
        """Return ``exp(-theta) * (theta / bin_size) ** s``.

        An infinite ``theta`` yields 1.0 so that an unbounded segment drops
        out of the product.
        """
        if not math.isinf(theta):
            return math.exp(-theta) * ((theta / bin_size) ** s)
        return 1.0

    def _get_diff(
        self,
        mut_set: Set[float],
        branch: "Branch",
        node: "Node",
    ) -> list:
        """Count state differences over every position in *mut_set*.

        Returns ``[s_lower, s_upper, s_query, s_old]``: differences between
        the majority state and the lower / upper / query node, followed by
        the differences between lower and upper node.
        """
        n_lower = 0
        n_upper = 0
        n_query = 0
        n_old = 0
        for pos in mut_set:
            low = branch.lower_node.get_state(pos)
            up = branch.upper_node.get_state(pos)
            qry = node.get_state(pos)
            # State at the join point: the majority of the three states.
            joint = 1 if (low + up + qry > 1.5) else 0
            n_lower += abs(joint - low)
            n_upper += abs(joint - up)
            n_query += abs(joint - qry)
            n_old += abs(low - up)
        return [n_lower, n_upper, n_query, n_old]

    def null_emit(
        self,
        branch: "Branch",
        time: float,
        theta: float,
        node: "Node",
    ) -> float:
        """Return the emission ratio when no mutation falls in the bin.

        A branch whose upper segment is infinite contributes neither an upper
        factor nor an "old" factor. A non-positive "old" probability yields
        1.0.
        """
        below = time - branch.lower_node.time
        above = branch.upper_node.time - time
        query_len = time - node.time
        unbounded = math.isinf(above)

        threaded = self._calculate_prob(below * theta, 1, 0)
        threaded = threaded * (
            1.0 if unbounded else self._calculate_prob(above * theta, 1, 0)
        )
        threaded = threaded * self._calculate_prob(query_len * theta, 1, 0)

        if unbounded:
            unthreaded = 1.0
        else:
            unthreaded = self._calculate_prob((below + above) * theta, 1, 0)

        if unthreaded > 0:
            return threaded / unthreaded
        return 1.0

    def mut_emit(
        self,
        branch: "Branch",
        time: float,
        theta: float,
        bin_size: float,
        mut_set: Set[float],
        node: "Node",
    ) -> float:
        """Return the emission ratio when *mut_set* mutations fall in the bin.

        The result is floored at 1e-20, which is also returned when the "old"
        probability is not positive.
        """
        below = time - branch.lower_node.time
        above = branch.upper_node.time - time
        query_len = time - node.time

        s_lower, s_upper, s_query, s_old = self._get_diff(mut_set, branch, node)

        p_lower = self._calculate_prob(below * theta, bin_size, s_lower)
        p_upper = self._calculate_prob(above * theta, bin_size, s_upper)
        p_query = self._calculate_prob(query_len * theta, bin_size, s_query)
        threaded = p_lower * p_upper * p_query

        unthreaded = self._calculate_prob((below + above) * theta, bin_size, s_old)
        if unthreaded <= 0:
            return 1e-20
        return max(threaded / unthreaded, 1e-20)

    def emit(
        self,
        branch: "Branch",
        time: float,
        theta: float,
        bin_size: float,
        emissions: list,
        node: "Node",
    ) -> float:
        """Return the emission ratio from pre-computed difference counts.

        ``emissions`` is indexed as ``[s_lower, s_upper, s_query, s_old]``.
        A non-positive "old" probability yields 1.0; no floor is applied.
        """
        below = time - branch.lower_node.time
        above = branch.upper_node.time - time
        query_len = time - node.time

        p_lower = self._calculate_prob(below * theta, bin_size, emissions[0])
        p_upper = self._calculate_prob(above * theta, bin_size, emissions[1])
        p_query = self._calculate_prob(query_len * theta, bin_size, emissions[2])
        threaded = p_lower * p_upper * p_query

        unthreaded = self._calculate_prob(
            (below + above) * theta, bin_size, emissions[3]
        )
        if unthreaded <= 0:
            return 1.0
        return threaded / unthreaded
# ---------------------------------------------------------------------------
# PolarEmission
# ---------------------------------------------------------------------------
class PolarEmission(Emission):
    """Polarised emission model for ancestral/derived alleles.

    Mirrors Polar_emission.cpp.

    Additional parameters:
        penalty:        Cost of derived alleles in the query node.
        ancestral_prob: Prior probability that allele is ancestral at root.
    """

    def __init__(self, penalty: float = 1.0, ancestral_prob: float = 0.5) -> None:
        self.penalty = penalty
        self.ancestral_prob = ancestral_prob
        # Scratch value written by _get_diff and read by mut_emit.
        self._root_reward: float = 1.0

    def _null_prob(self, theta: float) -> float:
        """Return ``exp(-theta)``, or 1.0 when *theta* is infinite."""
        if math.isinf(theta):
            return 1.0
        return math.exp(-theta)

    def _mut_prob_single(self, theta: float, bin_size: float, s: float) -> float:
        """Return ``(theta / bin_size) ** |s|``, or 1.0 when *theta* is infinite."""
        if math.isinf(theta):
            return 1.0
        unit_theta = theta / bin_size
        return unit_theta ** abs(s)

    def _survival_prob(self, theta: float, ll: float, lu: float, l0: float) -> float:
        """Return the no-mutation factor shared by all emission methods.

        For a finite upper segment only the query segment ``l0`` counts. When
        ``lu`` is infinite the lower segment ``ll`` is counted as well.
        """
        if not math.isinf(lu):
            return self._null_prob(theta * l0)
        return self._null_prob(theta * (ll + l0))

    def _get_diff(
        self,
        m: float,
        branch: "Branch",
        node: "Node",
    ) -> list:
        """Compute the directional difference vector at position *m*.

        Returns ``[sl - sm, sm - su, s0 - sm, sl - su]`` where ``sl``, ``su``
        and ``s0`` are the states of the lower, upper and query node and
        ``sm`` is their majority state.

        Side effect: overwrites ``self._root_reward``. It becomes
        ``ancestral_prob / (1 - ancestral_prob)`` when the upper node has
        index -1, the majority state is 0 and the lower state is 1; otherwise
        it is reset to 1.0.
        """
        sl = branch.lower_node.get_state(m)
        su = branch.upper_node.get_state(m)
        s0 = node.get_state(m)
        sm = 1 if (sl + su + s0 > 1.5) else 0
        if branch.upper_node.index == -1 and sm == 0 and sl == 1:
            # NOTE(review): raises ZeroDivisionError when ancestral_prob == 1.0.
            self._root_reward = self.ancestral_prob / (1.0 - self.ancestral_prob)
        else:
            self._root_reward = 1.0
        return [sl - sm, sm - su, s0 - sm, sl - su]

    def null_emit(
        self,
        branch: "Branch",
        time: float,
        theta: float,
        node: "Node",
    ) -> float:
        """Return the emission probability when no mutations fall in the bin."""
        ll = time - branch.lower_node.time
        lu = branch.upper_node.time - time
        l0 = time - node.time
        return self._survival_prob(theta, ll, lu, l0)

    def mut_emit(
        self,
        branch: "Branch",
        time: float,
        theta: float,
        bin_size: float,
        mut_set: Set[float],
        node: "Node",
    ) -> float:
        """Return the emission ratio when *mut_set* mutations fall in the bin.

        ``penalty`` is applied once for every position whose query difference
        ``s0 - sm`` is at least 1. The result is floored at 1e-20, which is
        also returned when the "old" probability is not positive.
        """
        ll = time - branch.lower_node.time
        lu = branch.upper_node.time - time
        l0 = time - node.time
        emit_prob = 1.0
        old_prob = 1.0
        self._root_reward = 1.0
        for m in mut_set:
            diff = self._get_diff(m, branch, node)
            emit_prob *= (
                self._mut_prob_single(ll * theta, bin_size, diff[0])
                * self._mut_prob_single(lu * theta, bin_size, diff[1])
                * self._mut_prob_single(l0 * theta, bin_size, diff[2])
            )
            if diff[2] >= 1:
                emit_prob *= self.penalty
            old_prob *= self._mut_prob_single((ll + lu) * theta, bin_size, diff[3])
        emit_prob *= self._survival_prob(theta, ll, lu, l0)
        if old_prob <= 0:
            return 1e-20
        # NOTE(review): _get_diff overwrites _root_reward on every call, so
        # only the reward of the last position iterated is applied here.
        return max(emit_prob / old_prob * self._root_reward, 1e-20)

    def emit(
        self,
        branch: "Branch",
        time: float,
        theta: float,
        bin_size: float,
        emissions: list,
        node: "Node",
    ) -> float:
        """Return the emission ratio from pre-computed difference counts.

        ``emissions`` is indexed as ``[lower, upper, query, old]``. The
        no-mutation factor is the one used by ``null_emit`` and ``mut_emit``,
        so zero counts give exactly ``null_emit``. Previously the denominator
        carried an extra ``exp(-theta * (ll + lu))`` with no counterpart in
        the numerator, which inflated the ratio.

        Neither ``penalty`` nor the root reward is applied here. A
        non-positive "old" probability yields 1.0.
        """
        ll = time - branch.lower_node.time
        lu = branch.upper_node.time - time
        l0 = time - node.time
        emit_prob = (
            self._mut_prob_single(ll * theta, bin_size, emissions[0])
            * self._mut_prob_single(lu * theta, bin_size, emissions[1])
            * self._mut_prob_single(l0 * theta, bin_size, emissions[2])
        )
        emit_prob *= self._survival_prob(theta, ll, lu, l0)
        old_prob = self._mut_prob_single((ll + lu) * theta, bin_size, emissions[3])
        if old_prob <= 0:
            return 1.0
        return emit_prob / old_prob