Build SINGER yourself
This chapter walks you through implementing SINGER from scratch. Each step tells you exactly which class or function to write next, shows you the Python code, explains the math in plain language, and connects it all together. By the end you will have a working Bayesian ARG sampler.
The guide follows the dependency order: data structures first, then the math helpers, then the two HMMs, then the MCMC loop, and finally I/O.
Step 1 — The Node class
Goal. Represent a single vertex in the ARG.
Every genealogical tree is made of nodes connected by branches. A Node is a vertex — it could be a present-day sample (a haplotype you observed) or an ancestor where two lineages coalesced. Each node needs three pieces of information:
time: when this ancestor lived, measured in coalescent units (where 1 unit = \(N_e\) generations). Present-day samples have time = 0. The further back in time, the larger the number. You also need a special root sentinel with time = inf; every tree in the ARG hangs off this sentinel so you never have to special-case “which node is the root?”
index: an integer label. Samples get 0, 1, …, n-1. Internal nodes get larger numbers. The root sentinel gets -1.
mutation_sites: a sorted map from genomic position to allele state. If a node carries the derived allele at position 500, then mutation_sites[500] = 1. We use a SortedDict because we need efficient range queries later. A sentinel entry at position -1 with state 0 ensures that any lookup before the first real site returns 0 (ancestral) without special-casing.
from dataclasses import dataclass, field
from sortedcontainers import SortedDict

@dataclass
class Node:
    time: float
    index: int = 0
    mutation_sites: SortedDict = field(
        default_factory=lambda: SortedDict({-1: 0})
    )

    def add_mutation(self, pos: float) -> None:
        """Record that this node carries the derived allele at pos."""
        self.mutation_sites[pos] = 1

    def get_state(self, pos: float) -> float:
        """Return allele state at pos (exact match only, else 0).

        Mutations are point events: a mutation at position 100 does NOT
        carry its state to position 101. This is just a dict lookup.
        """
        return self.mutation_sites.get(pos, 0)

    def write_state(self, pos: float, state: float) -> None:
        """Set the allele state; state==0 removes the entry."""
        if state == 0:
            if pos != -1:  # never remove the sentinel
                self.mutation_sites.pop(pos, None)
        elif state == 1:
            self.mutation_sites[pos] = state

    # CRITICAL: identity-based equality and hashing
    def __hash__(self):
        return id(self)

    def __eq__(self, other):
        return self is other
Why identity semantics? In the ARG, two different coalescence events can happen at the same time. If you compared nodes by their time and index fields, you couldn’t tell them apart. By using self is other for equality and id(self) for hashing, every Node object is unique — even if two nodes have identical fields. This mirrors the C++ implementation where nodes are compared by pointer address.
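To see the difference in practice, here is a minimal sketch that contrasts value-based and identity-based comparison. ValueNode and IdentityNode are hypothetical stand-ins invented for this illustration (they omit mutation_sites), and eq=False is just one convenient way to get identity behaviour from a dataclass; it is not necessarily how pysinger does it.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ValueNode:
    """Hypothetical node compared by field values (what we want to avoid)."""
    time: float
    index: int = 0


@dataclass(eq=False)
class IdentityNode:
    """Hypothetical node compared by object identity, like Node above."""
    time: float
    index: int = 0


# Two separate coalescence events that happen to share time and index.
v1, v2 = ValueNode(1.5, 7), ValueNode(1.5, 7)
n1, n2 = IdentityNode(1.5, 7), IdentityNode(1.5, 7)

value_set = {v1, v2}       # the two events collapse into a single key
identity_set = {n1, n2}    # both events stay distinct

print(len(value_set), len(identity_set))
```

With value semantics the second event silently overwrites the first whenever nodes are used as dictionary keys, which is exactly the failure the identity rule prevents.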
Code:
pysinger/data/node.py
Step 2 — The Branch class
Goal. Represent a directed edge between two nodes.
A branch connects a child node (lower_node, closer to the present) to a parent node (upper_node, further in the past). The invariant is lower_node.time ≤ upper_node.time.
class Branch:
    __slots__ = ("lower_node", "upper_node")

    def __init__(self, lower_node=None, upper_node=None):
        object.__setattr__(self, "lower_node", lower_node)
        object.__setattr__(self, "upper_node", upper_node)

    def __setattr__(self, name, value):
        raise AttributeError("Branch is immutable")

    def __eq__(self, other):
        if not isinstance(other, Branch):
            return False
        return (self.lower_node is other.lower_node
                and self.upper_node is other.upper_node)

    def __hash__(self):
        return hash((id(self.lower_node), id(self.upper_node)))

    @property
    def length(self):
        if self.lower_node is None or self.upper_node is None:
            return float("inf")
        return self.upper_node.time - self.lower_node.time

    def is_null(self):
        return self.lower_node is None and self.upper_node is None

    def __bool__(self):
        return not self.is_null()
Three important design decisions:
Immutability. After construction, you cannot change which nodes a branch connects. We enforce this by overriding __setattr__ to raise an error, and using object.__setattr__ inside __init__ to bypass our own guard. This is necessary because branches are used as dictionary keys and set members: if they could change, the hash would become invalid.
Identity-based hashing. Just like nodes, branches are compared by the identity of their node objects, not by the values of their fields. __hash__ uses id(self.lower_node) and id(self.upper_node). This means two Branch objects compare equal only when they wrap the exact same node objects in memory, not merely nodes with equal fields.
Null branch. Branch() with both nodes as None is used as a sentinel throughout the codebase, for example to mark the end of a lineage that has been traced to its boundary. bool(branch) returns False for null branches, so you can write if branch: to check.
Code:
pysinger/data/branch.py
Step 3 — The Tree class
Goal. Represent a marginal coalescent tree at a single genomic position.
An ARG is a sequence of trees along the genome. At each position, there is one tree that describes how the samples are related. We store each tree as two dictionaries:
class Tree:
    def __init__(self):
        self.parents = {}    # child Node → parent Node
        self.children = {}   # parent Node → set of child Nodes
These two dicts are redundant (one could be derived from the other), but keeping both makes lookups fast in both directions.
Adding and removing branches
Every topology change is built from two atomic operations:
def insert_branch(self, branch):
    """Add a branch to the tree."""
    ln, un = branch.lower_node, branch.upper_node
    self.parents[ln] = un
    if un not in self.children:
        self.children[un] = set()
    self.children[un].add(ln)

def delete_branch(self, branch):
    """Remove a branch from the tree."""
    ln, un = branch.lower_node, branch.upper_node
    self.parents.pop(ln, None)
    if un in self.children:
        self.children[un].discard(ln)
        if not self.children[un]:
            del self.children[un]
Moving the tree along the genome
As you walk along the genome, the tree changes at recombination breakpoints. Each change is recorded as a Recombination object (Step 4) that lists which branches to delete and which to insert:
def forward_update(self, r):
    """Apply recombination r, moving right along the genome."""
    for b in r.deleted_branches:
        self.delete_branch(b)
    for b in r.inserted_branches:
        self.insert_branch(b)

def backward_update(self, r):
    """Undo recombination r, moving left along the genome."""
    for b in r.inserted_branches:
        self.delete_branch(b)
    for b in r.deleted_branches:
        self.insert_branch(b)
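The sketch below checks that the backward update really undoes the forward one on a toy tree. It is a simplified illustration: nodes are plain string labels, branches are (child, parent) tuples, only the parents map is tracked, and the helper name apply is made up for this example.

```python
def apply(parents, remove, add):
    """Delete then insert (child, parent) branches in a child -> parent map."""
    for child, _ in remove:
        parents.pop(child, None)
    for child, parent in add:
        parents[child] = parent


# Tree to the left of the breakpoint: ((A, S)P, B)G.
parents = {"A": "P", "S": "P", "P": "G", "B": "G"}
left_tree = dict(parents)

# One SPR: A detaches from P and re-coalesces with B at a new node Q.
deleted = [("A", "P"), ("S", "P"), ("P", "G"), ("B", "G")]
inserted = [("A", "Q"), ("B", "Q"), ("Q", "G"), ("S", "G")]

apply(parents, deleted, inserted)     # plays the role of forward_update
right_tree = dict(parents)

apply(parents, inserted, deleted)     # plays the role of backward_update
print(right_tree)
print(parents == left_tree)
```

Because the two operations are exact mirrors, you can walk the genome in either direction without ever rebuilding a tree from scratch.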
Finding the joining branch (key for MCMC)
When you remove a lineage from a tree (the core MCMC move), you need to know what happens to the remaining tree. If you remove a branch, its coalescence node disappears, and the sibling gets connected directly to the grandparent. find_joining_branch computes this:
def find_sibling(self, n):
    """Return the other child of n's parent."""
    p = self.parents[n]
    for c in self.children[p]:
        if c is not n:
            return c
    return None

def find_joining_branch(self, removed_branch):
    """After removing this branch, what branch 'takes over'?

    Returns Branch(sibling, grandparent).
    """
    if removed_branch.is_null():
        return Branch()
    sibling = self.find_sibling(removed_branch.lower_node)
    if sibling is None:
        return Branch()
    grandparent = self.parents.get(removed_branch.upper_node)
    if grandparent is None:
        return Branch(sibling, removed_branch.upper_node)
    return Branch(sibling, grandparent)
Think of it this way: in a binary tree, every internal node has exactly two children. If you remove one child, the internal node becomes unnecessary, so you can “short-circuit” the sibling directly to the grandparent. This is the “prune” half of a subtree-prune-and-regraft (SPR) operation, i.e. the reverse of a regraft.
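As a concrete check, here is a small sketch of the same short-circuit logic on a toy tree. It uses string labels and tuples in place of Node and Branch objects (so it compares by value rather than identity), and the function name find_joining is an illustrative choice, not part of pysinger.

```python
def find_joining(parents, children, lower, upper):
    """Return (sibling, grandparent) for the removed branch (lower, upper)."""
    siblings = [c for c in children[upper] if c != lower]
    if not siblings:
        return None
    sibling = siblings[0]
    grandparent = parents.get(upper)
    if grandparent is None:
        # upper has no parent recorded: keep the sibling attached to it
        return (sibling, upper)
    return (sibling, grandparent)


# Toy tree ((A, S)P, B)G stored in both directions.
parents = {"A": "P", "S": "P", "P": "G", "B": "G"}
children = {"P": {"A", "S"}, "G": {"P", "B"}}

joining = find_joining(parents, children, "A", "P")
top_case = find_joining(parents, children, "P", "G")
print(joining, top_case)
```

Removing (A, P) makes P redundant, so S joins G directly. Removing (P, G) exercises the fallback, because G has no recorded parent in this toy tree.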
Total branch length
def length(self):
    """Sum of all branch lengths (excluding root sentinel branches)."""
    total = 0.0
    for child, parent in self.parents.items():
        if parent.index != -1:  # skip root sentinel
            total += parent.time - child.time
    return total
This is used later by the rescaling step to calibrate coalescence times to the mutation rate.
Code:
pysinger/data/tree.py
Step 4 — The Recombination class
Goal. Record a single topology change at a genomic breakpoint.
Under the Sequentially Markov Coalescent (SMC), adjacent marginal trees differ by exactly one SPR operation. At position pos, one lineage detaches from its current parent and re-attaches somewhere else. A Recombination stores the “before” and “after” as two sets of branches:
class Recombination:
    def __init__(self, deleted_branches=None, inserted_branches=None):
        self.pos = 0.0
        self.deleted_branches = set(deleted_branches or [])    # exist BEFORE pos
        self.inserted_branches = set(inserted_branches or [])  # exist AFTER pos
        # Derived fields (computed below)
        self.deleted_node = None    # old coalescence node being removed
        self.inserted_node = None   # new coalescence node being created
        self.source_branch = Branch()          # lineage that recombines
        self.target_branch = Branch()          # lineage it re-coalesces with
        self.merging_branch = Branch()         # sibling→grandparent after removal
        self.recombined_branch = Branch()      # source below recombination height
        self.lower_transfer_branch = Branch()  # target below new node
        self.upper_transfer_branch = Branch()  # target above new node
        self.start_time = 0.0                  # height of recombination
        if self.deleted_branches or self.inserted_branches:
            self._simplify_branches()
            self._find_nodes()
Deriving the named branches
The SPR operation involves several named branches that the HMMs need. Here is how to identify them from the deleted/inserted sets:
_find_nodes() — identify which coalescence nodes were destroyed and created:
def _find_nodes(self):
    """The deleted_node appears as an upper_node in deleted but not inserted.
    The inserted_node appears as an upper_node in inserted but not deleted."""
    prev_upper = {b.upper_node for b in self.deleted_branches}
    next_upper = {b.upper_node for b in self.inserted_branches}
    for n in prev_upper:
        if n not in next_upper:
            self.deleted_node = n
    for n in next_upper:
        if n not in prev_upper:
            self.inserted_node = n
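The set-difference idea is easy to check by hand. The following sketch uses (lower, upper) tuples of made-up node labels instead of Branch objects: lineage A leaves its parent P and re-coalesces with the branch (B, U) at a new node Q, while A's sibling S is promoted to the grandparent G.

```python
# Branches as (lower, upper) tuples; all labels are illustrative.
deleted = {("A", "P"), ("S", "P"), ("P", "G"), ("B", "U")}
inserted = {("A", "Q"), ("B", "Q"), ("Q", "U"), ("S", "G")}

prev_upper = {upper for _, upper in deleted}
next_upper = {upper for _, upper in inserted}

deleted_nodes = prev_upper - next_upper     # upper nodes that vanish
inserted_nodes = next_upper - prev_upper    # upper nodes that appear

print(deleted_nodes, inserted_nodes)
```

G and U appear as upper nodes on both sides, so they cancel; only the old coalescence node P and the new one Q survive the difference.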
_find_target_branch() — find the branch that was split by the new coalescence node:
def _find_target_branch(self):
    """The target branch is the deleted branch (other than source) that gets
    split into two inserted branches at inserted_node.time."""
    t = self.inserted_node.time
    for b in self.deleted_branches:
        if b == self.source_branch:
            continue
        if b.lower_node.time > t or b.upper_node.time < t:
            continue
        lower = Branch(b.lower_node, self.inserted_node)
        upper = Branch(self.inserted_node, b.upper_node)
        if lower in self.inserted_branches and upper in self.inserted_branches:
            self.target_branch = b
            return
_find_recomb_info() — compute the remaining named branches:
def _find_recomb_info(self):
    dn = self.deleted_node
    l, u = None, None
    for b in self.deleted_branches:
        if b == self.source_branch:
            continue
        if b.upper_node is dn:
            l = self.inserted_node if b == self.target_branch else b.lower_node
        elif b.lower_node is dn:
            u = self.inserted_node if b == self.target_branch else b.upper_node
    self.merging_branch = Branch(l, u)
    self.recombined_branch = Branch(
        self.source_branch.lower_node, self.inserted_node
    )
    # Transfer branches: how the target branch splits around inserted_node
    candidate_lower = Branch(self.target_branch.lower_node, self.inserted_node)
    self.lower_transfer_branch = (
        candidate_lower if candidate_lower in self.inserted_branches
        else self.merging_branch
    )
    candidate_upper = Branch(self.inserted_node, self.target_branch.upper_node)
    self.upper_transfer_branch = (
        candidate_upper if candidate_upper in self.inserted_branches
        else self.merging_branch
    )
To visualise what these branches mean, consider an SPR where lineage A detaches from its parent P and re-attaches to branch B at a new node Q:
BEFORE (deleted):              AFTER (inserted):
      P                              Q
     / \                            / \
    A   S                          A   B_lower
        |                              |
       ...                            ...
B = (B_lower, B_upper)         P removed; S connects to grandparent
source_branch  = (A, P)        recombined_branch = (A, Q)
target_branch  = (B_lower, B_upper)
merging_branch = (S, grandparent)
lower_transfer = (B_lower, Q)
upper_transfer = (Q, B_upper)
Tracing a lineage through a recombination
The MCMC needs to follow a lineage as it passes through recombination breakpoints. trace_forward answers: “if I was on branch at time t before this recombination, what branch am I on after?”
def trace_forward(self, t, branch):
    if not self.affect(branch):
        return branch  # this branch is not affected
    if branch == self.source_branch:
        if t >= self.start_time:
            return Branch()  # lineage was cut above the recombination
        return self.recombined_branch
    if branch == self.target_branch:
        if t > self.inserted_node.time:
            return self.upper_transfer_branch
        return self.lower_transfer_branch
    return self.merging_branch  # was sibling/parent of source
Simplifying
After any modification, branches that appear in both sets cancel out:
def _simplify_branches(self):
    common = self.deleted_branches & self.inserted_branches
    self.deleted_branches -= common
    self.inserted_branches -= common
Code:
pysinger/data/recombination.py
Step 5 — The Interval and IntervalInfo classes
Goal. Represent a single cell (branch, [lb, ub]) in the HMM state space.
The BSP and TSP are Hidden Markov Models whose states are regions of the tree. Each state is: “the new lineage coalesces on this branch at a time between lb and ub.” An Interval is one such state.
import math

class Interval:
    __slots__ = (
        "branch", "lb", "ub", "start_pos",
        "weight", "time",
        "source_weights", "source_intervals", "node",
    )

    def __init__(self, branch, lb, ub, start_pos):
        self.branch = branch
        self.lb = lb
        self.ub = ub
        self.start_pos = start_pos   # HMM step where this interval was created
        self.weight = 0.0            # coalescent probability mass in [lb, ub]
        self.time = 0.0              # representative time point
        self.source_weights = []     # traceback pointers (for BSP transfer)
        self.source_intervals = []
        self.node = None             # optional Node (for TSP point masses)
fill_time() — the exponential median
Each interval needs a single representative time point for computing emission and transition probabilities. We use the exponential median: the time \(t\) where the \(\text{Exp}(1)\) CDF is halfway between \(F(lb)\) and \(F(ub)\).
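The \(\text{Exp}(1)\) CDF is \(F(t) = 1 - e^{-t}\). The exponential median is:
\[ t^{*} = -\ln\!\left(1 - \tfrac{1}{2}\left[F(lb) + F(ub)\right]\right) = -\ln\!\left(\tfrac{1}{2}\left[e^{-lb} + e^{-ub}\right]\right) \]
For an unbounded interval (\(ub = \infty\)) this reduces to \(t^{*} = lb + \ln 2\).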
Why not the arithmetic midpoint? Because coalescent times are exponentially distributed. Near the bottom of a branch (close to the present), times are dense; near the top, they are sparse. The exponential median picks a representative point that respects this density — it sits where the coalescent probability mass is concentrated, not at the geometric centre.
def fill_time(self):
    lb, ub = self.lb, self.ub
    if math.isinf(ub):
        self.time = lb + math.log(2)   # median of Exp shifted by lb
        return
    if abs(lb - ub) < 1e-3:
        self.time = 0.5 * (lb + ub)    # tiny interval: midpoint is fine
        return
    lq = 1.0 - math.exp(-lb)
    uq = 1.0 - math.exp(-ub)
    if uq - lq < 1e-3:
        self.time = 0.5 * (lb + ub)
    else:
        q = 0.5 * (lq + uq)
        self.time = -math.log(1.0 - q)
    self.time = max(lb, min(ub, self.time))  # clamp
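To get a feel for how far the exponential median sits from the arithmetic midpoint, here is a small standalone sketch. The function names exp_median and cdf are chosen for this example only, and the sketch deliberately leaves out the small-interval guards that fill_time uses.

```python
import math


def exp_median(lb, ub):
    """Time t where F(t) is halfway between F(lb) and F(ub), F(t) = 1 - exp(-t)."""
    if math.isinf(ub):
        return lb + math.log(2)
    return -math.log(0.5 * (math.exp(-lb) + math.exp(-ub)))


def cdf(t):
    """Exp(1) cumulative distribution function."""
    return 1.0 - math.exp(-t)


lb, ub = 0.0, 2.0
t_med = exp_median(lb, ub)
midpoint = 0.5 * (lb + ub)

# Probability mass on each side of the representative point.
below = cdf(t_med) - cdf(lb)
above = cdf(ub) - cdf(t_med)

print(f"exponential median = {t_med:.4f}, midpoint = {midpoint:.4f}")
print(f"mass below = {below:.4f}, mass above = {above:.4f}")
```

The two masses come out equal by construction, and the representative time lands well below the midpoint because most of the Exp(1) mass in [0, 2] sits near 0.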
full(cut_time) — is this a full interval?
A “full” interval spans the entire branch above cut_time. Only full intervals participate in the BSP’s recombination weight calculation — partial intervals (created during transfer steps) don’t contribute to the recombination sum.
def full(self, cut_time):
    lb_expected = max(cut_time, self.branch.lower_node.time)
    ub_expected = self.branch.upper_node.time
    return self.lb == lb_expected and self.ub == ub_expected
IntervalInfo is a lightweight hashable key (branch, lb, ub) used as dict keys during BSP transfer to accumulate probability mass from multiple source intervals that map to the same target region.
Code:
pysinger/data/interval.py
Step 6 — The ARG class
Goal. The central data structure — the entire Ancestral Recombination Graph.
An ARG is stored as a sorted map from genomic position to Recombination record, with sentinel entries at position 0 and at an effectively infinite position (float(sys.maxsize) in the code below). To get the marginal tree at any position \(x\), replay all records from 0 up to \(x\):
import math
import sys
from sortedcontainers import SortedDict

class ARG:
    def __init__(self, Ne=1.0, sequence_length=1.0):
        self.Ne = Ne
        self.sequence_length = sequence_length
        self.root = Node(time=math.inf, index=-1)  # root sentinel
        self.sample_nodes = set()
        self.recombinations = SortedDict()
        # Sentinel records at boundaries
        r0 = Recombination()
        r0.pos = 0.0
        self.recombinations[0.0] = r0
        r_end = Recombination()
        r_end.pos = float(sys.maxsize)
        self.recombinations[float(sys.maxsize)] = r_end
        # Mutation tracking
        self.mutation_sites = SortedDict()   # position → True
        self.mutation_branches = {}          # position → set of Branches
        # HMM grid
        self.coordinates = []
        self.rhos = []     # per-bin recombination rates
        self.thetas = []   # per-bin mutation rates
        # MCMC working state
        self.removed_branches = SortedDict()
        self.joining_branches = SortedDict()
        self.cut_tree = Tree()
        self.cut_time = 0.0
        self.cut_node = None
        self.start = 0.0
        self.end = sequence_length

    def get_tree_at(self, x):
        """Replay recombinations to get the marginal tree at position x."""
        tree = Tree()
        for pos, r in self.recombinations.items():
            if pos <= x:
                tree.forward_update(r)
            else:
                break
        return tree
Building the simplest ARG
def build_singleton_arg(self, node):
    """ARG with one sample connected to the root sentinel."""
    self.add_sample(node)
    branch = Branch(node, self.root)
    r0 = Recombination(set(), {branch})
    r0.pos = 0.0
    self.recombinations[0.0] = r0
The HMM coordinate grid
The BSP and TSP operate on a discrete grid along the genome. Each grid cell is one “bin” with its own recombination rate \(\rho\) and mutation rate \(\theta\). Grid points are placed at regular intervals and at every recombination breakpoint (so that topology changes always land exactly on a grid boundary):
def discretize(self, bin_size):
    self.coordinates = []
    recomb_keys = list(self.recombinations.keys())
    recomb_idx = 1  # skip sentinel at 0
    curr_pos = 0.0
    while curr_pos < self.sequence_length:
        self.coordinates.append(curr_pos)
        next_recomb = (recomb_keys[recomb_idx]
                       if recomb_idx < len(recomb_keys)
                       else float("inf"))
        if next_recomb < curr_pos + bin_size:
            # A breakpoint falls inside this bin: end the bin there
            curr_pos = next_recomb
            recomb_idx += 1
        else:
            curr_pos = min(curr_pos + bin_size, self.sequence_length)
    self.coordinates.append(self.sequence_length)

def compute_rhos_thetas(self, r, m):
    """Per-bin rates: rho = r * span, theta = m * span.
    r and m are already in coalescent units (multiplied by Ne)."""
    self.rhos = []     # reset so repeated calls do not append duplicates
    self.thetas = []
    for i in range(len(self.coordinates) - 1):
        span = self.coordinates[i + 1] - self.coordinates[i]
        self.rhos.append(r * span)
        self.thetas.append(m * span)
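The grid logic is easier to follow on concrete numbers. Below is a standalone sketch of the same loop as a free function; the breakpoint positions, sequence length, bin size and rate are invented values, and the breakpoint list is assumed to be sorted and to exclude the sentinel at 0.

```python
def discretize(breakpoints, sequence_length, bin_size):
    """Grid with a point every bin_size and at every (sorted) breakpoint."""
    coords = []
    idx = 0
    curr = 0.0
    while curr < sequence_length:
        coords.append(curr)
        nxt = breakpoints[idx] if idx < len(breakpoints) else float("inf")
        if nxt < curr + bin_size:
            curr = nxt          # cut the bin short at the breakpoint
            idx += 1
        else:
            curr = min(curr + bin_size, sequence_length)
    coords.append(sequence_length)
    return coords


coords = discretize([250.0, 1020.0], 2000.0, 500.0)
spans = [b - a for a, b in zip(coords, coords[1:])]

r = 1e-4                          # illustrative per-base rate
rhos = [r * s for s in spans]     # per-bin recombination rates

print(coords)
print(spans)
```

Note that the regular spacing restarts after each breakpoint, so bins are at most bin_size wide but not aligned to multiples of it.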
remove(cut_point) — extracting a lineage (MCMC)
This is the most complex operation. Given a cut point (pos, branch, cut_time), it traces the lineage forward and backward through recombination records, removing it from the ARG:
def remove(self, cut_point):
    # Pseudocode sketch: the loop header and the backward pass are schematic.
    pos, center_branch, t = cut_point
    self.cut_time = t
    self.cut_node = Node(time=t, index=-2)  # sentinel for the cut
    # Forward pass: trace the branch to the right.
    # `tree` is the marginal tree at the current position.
    prev_removed = center_branch
    for each recombination r to the right of pos:
        joining = tree.find_joining_branch(prev_removed)
        tree.forward_update(r)
        next_removed = r.trace_forward(t, prev_removed)
        next_joining = tree.find_joining_branch(next_removed)
        r.remove(prev_removed, next_removed, joining, next_joining, self.cut_node)
        self.removed_branches[r.pos] = next_removed
        self.joining_branches[r.pos] = next_joining
        prev_removed = next_removed
    # Backward pass: trace the branch to the left (symmetric)
    # ...same logic using trace_backward...
At each step, r.remove(...) updates the recombination record to reflect the topology without the removed lineage. The result is two maps: removed_branches (what was taken out at each position) and joining_branches (what filled the gap).
add(joining_branches, added_branches) — threading a lineage back in
The reverse of remove. Walk through the added_branches map, updating or creating recombination records:
def add(self, new_joining_branches, added_branches):
    # Pseudocode sketch: the condition below is schematic.
    for pos in added_branches:
        if pos is at an existing recombination:
            r.add(prev_added, next_added, prev_joining, next_joining, cut_node)
        else:
            self._new_recombination(pos, prev_added, prev_joining,
                                    next_added, next_joining)
    self._impute(new_joining_branches, added_branches)  # assign mutations
_impute() — majority-rule mutation assignment
When you thread a new lineage, you create new coalescence nodes that need allele states. For each segregating site, look at the three nodes around the new coalescence: the lower node (\(s_l\)), the upper node (\(s_u\)), and the query node (\(s_0\)). The new node gets the majority-rule state:
\[ s_m = \begin{cases} 1 & \text{if } s_l + s_u + s_0 \geq 2 \\ 0 & \text{otherwise} \end{cases} \]
This is the simplest parsimony assignment — whichever allele is in the majority wins. It is not exact (Fitch parsimony does better), but it is fast and sufficient for the MCMC.
def _map_mutation_branch(self, x, joining_branch, added_branch):
    sl = joining_branch.lower_node.get_state(x)
    su = joining_branch.upper_node.get_state(x)
    s0 = added_branch.lower_node.get_state(x)
    sm = 1 if (sl + su + s0 > 1) else 0
    added_branch.upper_node.write_state(x, sm)
    # Update mutation_branches tracking...
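The majority rule has only eight possible inputs, so it can be tabulated exhaustively. The sketch below does that and also counts how many state changes the rule implies on the three surrounding branches; the helper names majority and changes are made up for this illustration.

```python
from itertools import product


def majority(sl, su, s0):
    """Majority-rule state of the new node given its three neighbours."""
    return 1 if (sl + su + s0) > 1 else 0


def changes(sl, su, s0):
    """Number of branches around the new node that carry a state change."""
    sm = majority(sl, su, s0)
    return abs(sm - sl) + abs(sm - su) + abs(sm - s0)


table = {s: majority(*s) for s in product((0, 1), repeat=3)}
worst = max(changes(*s) for s in table)

for states, sm in sorted(table.items()):
    print(states, "->", sm)
print("most state changes needed:", worst)
```

Whatever the three neighbouring states are, the majority rule never needs more than one state change around the new node, which is why it works as a cheap parsimony assignment.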
get_arg_length() — total branch length
Walk all marginal trees, summing tree.length() * genomic_span for each interval between recombinations. This gives the total amount of “evolutionary opportunity” for mutations in the ARG, which the rescaling step uses.
def get_arg_length(self):
    tree = self.get_tree_at(0.0)
    prev_pos = 0.0
    total = 0.0
    for r_pos, r in self.recombinations.items():
        if r_pos == 0:
            continue
        next_pos = min(r_pos, self.sequence_length)
        total += tree.length() * (next_pos - prev_pos)
        if r_pos >= self.sequence_length:
            break
        tree.forward_update(r)
        prev_pos = next_pos
    return total
sample_internal_cut() — choosing where to cut
For the MCMC, you need to pick a random point in the ARG to cut. Draw a random time uniformly from [0, max_tree_height], then pick uniformly among the branches that span that time:
from random import random, randint

def sample_internal_cut(self):
    # Find max non-inf node time in the cut tree
    max_time = max(p.time for p in self.cut_tree.parents.values()
                   if not math.isinf(p.time))
    cut_time = random() * max_time
    # Find branches spanning that time (root sentinel branches excluded)
    candidates = [
        Branch(child, parent)
        for child, parent in self.cut_tree.parents.items()
        if not math.isinf(parent.time)
        and parent.time > cut_time and child.time <= cut_time
    ]
    # randint is inclusive at both ends, hence the "- 1"
    branch = candidates[randint(0, len(candidates) - 1)]
    return (self.cut_pos, branch, cut_time)
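Here is a self-contained sketch of the same two-stage draw on a toy tree, useful for convincing yourself that the candidate list is never empty. It assumes string labels with a separate times dictionary and an explicitly seeded random.Random instance; none of these names come from pysinger, and the toy tree has no infinite-time sentinel to filter out.

```python
import random


def sample_cut(parents, times, rng):
    """parents: child -> parent labels; times: label -> node time."""
    max_time = max(times[p] for p in parents.values())
    cut_time = rng.random() * max_time      # uniform in [0, max_time)
    candidates = [
        (c, p) for c, p in parents.items()
        if times[c] <= cut_time < times[p]
    ]
    return rng.choice(candidates), cut_time


# Toy tree ((a, b)P, c)G with P at 0.4 and G at 1.0.
times = {"a": 0.0, "b": 0.0, "c": 0.0, "P": 0.4, "G": 1.0}
parents = {"a": "P", "b": "P", "P": "G", "c": "G"}

rng = random.Random(1)
(child, parent), cut_time = sample_cut(parents, times, rng)
print(child, parent, round(cut_time, 3))
```

Because rng.random() is strictly below 1, the cut time is strictly below the tree height, so at least one branch (the one reaching the highest node) always spans it.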
Code:
pysinger/data/arg.py
Step 7 — The CoalescentCalculator
Goal. Compute the probability that a new lineage coalesces in a given time interval.
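When you add a new lineage to an existing tree, it “falls” from the present (time 0) upward into the past, and at each moment it can coalesce with any lineage alive at that time. The rate of coalescence at time \(t\) equals the number of branches alive at \(t\):
\[ \lambda(t) = \#\{\, b : \max(t_c,\, t_{\text{lower}}(b)) \leq t < t_{\text{upper}}(b) \,\} \]
where \(t_c\) is the cut time at which the new lineage starts (0 for a lineage starting at the present).
This rate is piecewise constant: it only changes when a branch starts or ends (at node times). Between those events, the coalescence process is exponential with a constant rate. The cumulative probability of coalescing by time \(t\) is:
\[ F(t) = 1 - \exp\!\left(-\int_{t_c}^{t} \lambda(s)\, ds\right) \]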
Because \(\lambda\) is piecewise constant, \(F\) is piecewise exponential — and we can compute it exactly.
Building the CDF
import bisect
import math
from sortedcontainers import SortedDict

class CoalescentCalculator:
    def __init__(self, cut_time):
        self.cut_time = cut_time

    def compute(self, branches):
        self._compute_rate_changes(branches)
        self._compute_rates()
        self._compute_probs_quantiles()
Step 7a: Record rate changes. For each branch, the coalescence rate goes up by 1 when the branch starts (at max(cut_time, lower_node.time)) and down by 1 when it ends (at upper_node.time):
def _compute_rate_changes(self, branches):
    self._rate_changes = SortedDict()
    for b in branches:
        lb = max(self.cut_time, b.lower_node.time)
        ub = b.upper_node.time
        self._rate_changes[lb] = self._rate_changes.get(lb, 0) + 1
        self._rate_changes[ub] = self._rate_changes.get(ub, 0) - 1
Step 7b: Running sum gives the piecewise constant rate.
def _compute_rates(self):
    self._rates = SortedDict()
    curr = 0
    for t, delta in self._rate_changes.items():
        curr += delta
        self._rates[t] = curr
Now self._rates[t] tells you: “from time \(t\) until the next rate change, the coalescence rate is this value.”
Step 7c: Piecewise exponential CDF. Walk pairs of adjacent rate-change times. Between \([t_k, t_{k+1})\) with rate \(\lambda_k\), the survival probability \(S(t) = 1 - F(t)\) decays exponentially:
\[ S(t_{k+1}) = S(t_k)\, e^{-\lambda_k (t_{k+1} - t_k)} \]
In words: the probability of not coalescing by time \(t_{k+1}\) equals the probability of surviving to \(t_k\) times the probability of surviving the interval \([t_k, t_{k+1})\) with rate \(\lambda_k\).
The cumulative coalescence probability grows by the mass that “coalesced” in this interval:
\[ F(t_{k+1}) = F(t_k) + S(t_k) - S(t_{k+1}) \]
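def _compute_probs_quantiles(self):
    self._cum_probs = SortedDict()
    rate_keys = list(self._rates.keys())
    # Earliest and latest rate-change times (min_time was previously undefined)
    self.min_time = rate_keys[0]
    self.max_time = rate_keys[-1]
    prev_prob = 1.0   # survival probability (starts at 1)
    cum_prob = 0.0    # cumulative coalescence probability
    for i in range(len(rate_keys) - 1):
        rate = self._rates[rate_keys[i]]
        dt = rate_keys[i + 1] - rate_keys[i]
        if rate > 0:
            next_prob = prev_prob * math.exp(-rate * dt)
            cum_prob += prev_prob - next_prob
        else:
            next_prob = prev_prob
        self._cum_probs[rate_keys[i + 1]] = cum_prob
        prev_prob = next_prob
    self._cum_probs[self.min_time] = 0.0
    # Parallel arrays used by quantile() for the inverse-CDF lookup
    self._prob_times = list(self._cum_probs.keys())
    self._prob_vals = list(self._cum_probs.values())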
Interpolated CDF lookup
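To get \(F(x)\) for an arbitrary time \(x\) (not just at rate-change boundaries), find the interval \([t_k, t_{k+1})\) containing \(x\) and interpolate:
\[ F(x) = F(t_k) + \left[F(t_{k+1}) - F(t_k)\right] \frac{1 - e^{-\lambda_k (x - t_k)}}{1 - e^{-\lambda_k (t_{k+1} - t_k)}} \]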
This is exact, not an approximation — because the rate is constant within the interval, the CDF really is exponential there. The expm1 function (math.expm1) computes \(e^x - 1\) accurately even for small \(x\), avoiding floating-point cancellation.
def prob(self, x):
    keys = self._cum_probs.keys()
    # Clamp to the tabulated range so the index lookups stay in bounds
    if x <= keys[0]:
        return 0.0
    if x >= keys[-1]:
        return self._cum_probs[keys[-1]]
    # Find interval [l_key, u_key] containing x
    u_idx = self._cum_probs.bisect_right(x)
    l_idx = u_idx - 1
    l_key = keys[l_idx]
    u_key = keys[u_idx]
    base_prob = self._cum_probs[l_key]
    rate = self._rates.get(l_key, 0)
    if rate == 0:
        return base_prob
    delta_t = u_key - l_key
    delta_p = self._cum_probs[u_key] - base_prob
    new_delta_t = x - l_key
    denom = math.expm1(-rate * delta_t)
    if abs(denom) < 1e-15:
        new_delta_p = delta_p * new_delta_t / delta_t  # linear fallback
    else:
        new_delta_p = delta_p * math.expm1(-rate * new_delta_t) / denom
    return base_prob + new_delta_p
Inverse CDF (quantile function)
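Given a probability \(p\), find time \(t\) such that \(F(t) = p\). This is the inverse of the above formula:
\[ t = t_k - \frac{1}{\lambda_k} \ln\!\left(1 - \frac{p - F(t_k)}{F(t_{k+1}) - F(t_k)} \left[1 - e^{-\lambda_k (t_{k+1} - t_k)}\right]\right) \]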
In plain language: find which segment of the CDF contains probability \(p\), then invert the exponential formula within that segment.
def quantile(self, p):
    # Find the segment where cum_prob crosses p
    idx = bisect.bisect_right(self._prob_vals, p)
    # Keep both indices in range (p = 0 or p = 1 would otherwise overflow)
    idx = max(1, min(idx, len(self._prob_vals) - 1))
    l_idx = idx - 1
    u_idx = idx
    l_time = self._prob_times[l_idx]
    l_prob = self._prob_vals[l_idx]
    delta_p = self._prob_vals[u_idx] - l_prob
    delta_t = self._prob_times[u_idx] - l_time
    rate = self._rates.get(l_time, 0)
    new_delta_p = p - l_prob
    if delta_p <= 0:
        return l_time  # no probability mass in this segment
    if rate == 0:
        new_delta_t = delta_t * new_delta_p / delta_p
    else:
        frac = new_delta_p / delta_p * (1.0 - math.exp(-rate * delta_t))
        if frac >= 1.0:
            return self._prob_times[u_idx]  # avoid log(0) at the segment end
        new_delta_t = -math.log(1.0 - frac) / rate
    return l_time + new_delta_t
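To check the formulas of Steps 7a to 7c against a case you can verify by hand, here is a compact standalone sketch using only the standard library. It works on plain (lower_time, upper_time) pairs instead of Branch objects; the function names build_cdf, prob and quantile and the list-based layout are choices made for this example and differ from the SortedDict-based class above.

```python
import bisect
import math


def build_cdf(branches, cut_time=0.0):
    """branches: (lower_time, upper_time) pairs. Returns (times, rates, cum)."""
    changes = {}
    for lo, hi in branches:
        start = max(cut_time, lo)
        changes[start] = changes.get(start, 0) + 1
        changes[hi] = changes.get(hi, 0) - 1
    times = sorted(changes)
    rates, running = [], 0
    for t in times:                      # running sum = piecewise rate
        running += changes[t]
        rates.append(running)
    cum, survive = [0.0], 1.0
    for k in range(len(times) - 1):      # piecewise exponential CDF
        nxt = survive * math.exp(-rates[k] * (times[k + 1] - times[k]))
        cum.append(cum[-1] + survive - nxt)
        survive = nxt
    return times, rates, cum


def prob(cdf, x):
    times, rates, cum = cdf
    x = min(max(x, times[0]), times[-1])
    k = min(bisect.bisect_right(times, x) - 1, len(times) - 2)
    seg = 1.0 - math.exp(-rates[k] * (times[k + 1] - times[k]))
    if seg == 0.0:
        return cum[k]
    part = 1.0 - math.exp(-rates[k] * (x - times[k]))
    return cum[k] + (cum[k + 1] - cum[k]) * part / seg


def quantile(cdf, p):
    times, rates, cum = cdf
    k = min(max(bisect.bisect_right(cum, p) - 1, 0), len(times) - 2)
    dp = cum[k + 1] - cum[k]
    if dp == 0.0 or rates[k] == 0:
        return times[k]
    seg = 1.0 - math.exp(-rates[k] * (times[k + 1] - times[k]))
    frac = (p - cum[k]) / dp * seg
    return times[k] - math.log(1.0 - frac) / rates[k]


# Two branches alive on [0, 1), one on [1, inf): rate 2, then rate 1.
cdf = build_cdf([(0.0, 1.0), (0.0, math.inf)])
f1 = prob(cdf, 1.0)
f2 = prob(cdf, 2.0)
t_back = quantile(cdf, f2)
print(f1, f2, t_back)
```

By hand, the integrated rate is 2 at time 1 and 3 at time 2, so the two probabilities should equal \(1 - e^{-2}\) and \(1 - e^{-3}\), and the quantile should map the second one back to time 2.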
The two main methods
def weight(self, lb, ub):
    """Probability of coalescence in [lb, ub]."""
    return self.prob(ub) - self.prob(lb)

def time(self, lb, ub):
    """Representative coalescence time in [lb, ub] (exponential median).
    Find t where F(t) is midway between F(lb) and F(ub)."""
    if math.isinf(ub):
        return lb + math.log(2)
    if ub - lb < 1e-3:
        return 0.5 * (lb + ub)
    lq = self.prob(lb)
    uq = self.prob(ub)
    if uq - lq < 1e-3:
        return 0.5 * (lb + ub)
    mid = 0.5 * (lq + uq)
    return max(lb, min(ub, self.quantile(mid)))
weight is used to set up the initial BSP forward probabilities. time is used to pick the representative time for each interval.
Code:
pysinger/hmm/coalescent.py
Step 8 — Emission models
Goal. Compute how well the observed mutations match a proposed coalescence point.
When you propose threading a new lineage onto branch \(b\) at time \(t\), you split \(b\) into three sub-branches:
                       upper_node (time u)
                           |
                       ℓ_upper = u - t
                           |
query_node ----[t]---- split point
                           |
ℓ_query = t - t_q      ℓ_lower = t - l
    |                      |
                       lower_node (time l)
The emission probability is the likelihood ratio: how much more likely is the observed data with the new lineage threaded in, compared to without it?
BinaryEmission (used by the TSP)
No mutations in the bin (null emission):
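The Poisson probability of seeing no mutations on a branch of length \(\ell\) with scaled rate \(\theta\) is \(e^{-\theta\ell}\). With the new lineage, the old branch \((\ell_l + \ell_u)\) is replaced by three sub-branches. The ratio is:
\[ \frac{e^{-\theta \ell_l}\, e^{-\theta \ell_u}\, e^{-\theta \ell_q}}{e^{-\theta (\ell_l + \ell_u)}} = e^{-\theta \ell_q} \]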
The lower and upper segments cancel in numerator and denominator (they’re the same total length), leaving only the query branch’s contribution. Longer query branches (coalescence further in the past) get penalised because they have more opportunity for mutations that weren’t observed.
class BinaryEmission:
    def null_emit(self, branch, time, theta, node):
        l_query = time - node.time
        return math.exp(-theta * l_query)
With mutations (mutation emission):
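For each segregating site, compute the majority-rule ancestral state from the three nodes around the split:
\[ s_m = \begin{cases} 1 & \text{if } s_l + s_u + s_0 \geq 2 \\ 0 & \text{otherwise} \end{cases} \]
where \(s_l\) = state of the lower node, \(s_u\) = state of the upper node, \(s_0\) = state of the query node. Then count state changes on each sub-branch: \(d_l = |s_m - s_l|\), \(d_u = |s_m - s_u|\), \(d_q = |s_m - s_0|\), and for the original unsplit branch \(d_{\text{old}} = |s_l - s_u|\).
Each state change on a sub-branch of length \(\ell\) contributes a factor of \(\theta\ell/\Delta x\): the per-site mutation rate \(\theta/\Delta x\) times the sub-branch length. With the counts \(d_l, d_u, d_q, d_{\text{old}}\) summed over all sites in the bin, the emission is:
\[ \frac{e^{-\theta \ell_l} \left(\frac{\theta \ell_l}{\Delta x}\right)^{d_l} \, e^{-\theta \ell_u} \left(\frac{\theta \ell_u}{\Delta x}\right)^{d_u} \, e^{-\theta \ell_q} \left(\frac{\theta \ell_q}{\Delta x}\right)^{d_q}}{e^{-\theta (\ell_l + \ell_u)} \left(\frac{\theta (\ell_l + \ell_u)}{\Delta x}\right)^{d_{\text{old}}}} \]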
The intuition: if the new lineage “explains” a mutation better (fewer total state changes), the emission is higher. If it adds unnecessary state changes, the emission is lower.
def _get_diff(self, mut_set, branch, node):
    """Returns [d_lower, d_upper, d_query, d_old]."""
    d = [0, 0, 0, 0]
    for x in mut_set:
        sl = branch.lower_node.get_state(x)
        su = branch.upper_node.get_state(x)
        s0 = node.get_state(x)
        sm = 1 if (sl + su + s0 > 1.5) else 0
        d[0] += abs(sm - sl)
        d[1] += abs(sm - su)
        d[2] += abs(sm - s0)
        d[3] += abs(sl - su)
    return d

def mut_emit(self, branch, time, theta, bin_size, mut_set, node):
    ll = time - branch.lower_node.time
    lu = branch.upper_node.time - time
    l0 = time - node.time
    diff = self._get_diff(mut_set, branch, node)
    emit = (self._calc(ll * theta, bin_size, diff[0])
            * self._calc(lu * theta, bin_size, diff[1])
            * self._calc(l0 * theta, bin_size, diff[2]))
    old = self._calc((ll + lu) * theta, bin_size, diff[3])
    return max(emit / old, 1e-20)

def _calc(self, theta, bin_size, s):
    """exp(-theta) * (theta/bin_size)^s"""
    return math.exp(-theta) * (theta / bin_size) ** s
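A worked example makes the ratio less abstract. The sketch below evaluates the same arithmetic for a single site where the lower node and the query carry the derived allele and the upper node does not. All numbers (times, theta, bin size) are invented for illustration, and calc is a free-function copy of _calc.

```python
import math


def calc(theta_len, bin_size, s):
    """exp(-theta*l) * (theta*l / bin_size)**s, mirroring _calc above."""
    return math.exp(-theta_len) * (theta_len / bin_size) ** s


# Branch from time 0.2 to 1.0, split at t = 0.5, query node at time 0.
lower_t, upper_t, t, query_t = 0.2, 1.0, 0.5, 0.0
theta, bin_size = 0.01, 100.0
sl, su, s0 = 1, 0, 1              # lower and query derived, upper ancestral

sm = 1 if sl + su + s0 > 1 else 0
d_l, d_u, d_q, d_old = abs(sm - sl), abs(sm - su), abs(sm - s0), abs(sl - su)

ll, lu, l0 = t - lower_t, upper_t - t, t - query_t
emit = (calc(ll * theta, bin_size, d_l)
        * calc(lu * theta, bin_size, d_u)
        * calc(l0 * theta, bin_size, d_q))
old = calc((ll + lu) * theta, bin_size, d_old)
ratio = emit / old

# One mutation before and one after, so theta/bin_size cancels out.
closed_form = math.exp(-theta * l0) * lu / (ll + lu)
print(ratio, closed_form)
```

The mutation needed on the old branch is still needed after threading (now on the upper segment), so the ratio only pays for the extra query branch length and for squeezing the mutation into the shorter upper segment.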
PolarEmission (used by the BSP)
Extends BinaryEmission with two refinements:
Derived allele penalty: if the query node differs from the majority (\(d_q \geq 1\)), multiply by a configurable penalty factor. This discourages placing the new lineage where it introduces unnecessary derived alleles.
Root reward: when the branch reaches the root sentinel and the majority state is ancestral while the lower node is derived, apply ancestral_prob / (1 - ancestral_prob). This biases toward the prior expectation that the root carries the ancestral allele.
class PolarEmission(Emission):
    def __init__(self, penalty=1.0, ancestral_prob=0.5):
        self.penalty = penalty
        self.ancestral_prob = ancestral_prob

    def null_emit(self, branch, time, theta, node):
        l0 = time - node.time
        lu = branch.upper_node.time - time
        if not math.isinf(lu):
            return math.exp(-theta * l0)
        else:
            ll = time - branch.lower_node.time
            return math.exp(-theta * (ll + l0))
Code:
pysinger/hmm/emission.py
Step 9 — The BSP (Branch Sequence Propagator)
Goal. A forward HMM that answers: which branch should the new lineage join at each genomic position?
The forward equation
At each position \(x\), every interval \(i\) carries a forward probability \(\alpha_i(x)\) representing the probability that the lineage is on branch \(i\)’s region, given all the data seen so far. Between positions, the lineage can either stay on the same branch or recombine (jump to a new branch):
where:
\(p_i = \rho \cdot (t_i - t_c) \cdot e^{-\rho(t_i - t_c)}\) is the probability that interval \(i\) recombines. As long as \(\rho(t_i - t_c)\) is small, a larger gap between the representative time \(t_i\) and the cut time \(t_c\) gives a higher recombination probability.
\(R = \sum_j p_j \cdot \alpha_j(x)\) is the total probability mass that recombines away from any interval.
\(w_i\) is the probability of landing on interval \(i\) after recombining. It’s proportional to recomb_prob * coalescent_weight for “full” intervals (ones spanning the entire branch) and zero for partial intervals, normalised to sum to 1.
class BSP:
    def start(self, branches, cut_time):
        """Initialise at the left boundary."""
        self.cut_time = cut_time
        self.valid_branches = {b for b in branches
                               if b.upper_node.time > cut_time}
        self.cc = CoalescentCalculator(cut_time)
        self.cc.compute(self.valid_branches)
        self.curr_intervals = []
        initial_probs = []
        for b in sorted(self.valid_branches):
            lb = max(b.lower_node.time, cut_time)
            ub = b.upper_node.time
            iv = Interval(b, lb, ub, 0)
            self.curr_intervals.append(iv)
            initial_probs.append(self.cc.weight(lb, ub))
        self.forward_probs = [initial_probs]

    def forward(self, rho):
        """Advance by one bin (no topology change)."""
        prev_fp = self.forward_probs[-1]
        # Compute recombination probabilities
        recomb_probs = []
        for iv in self.curr_intervals:
            dt = iv.time - self.cut_time
            recomb_probs.append(rho * dt * math.exp(-rho * dt))
        # Compute recombination weights (only full intervals contribute)
        recomb_weights = []
        for i, iv in enumerate(self.curr_intervals):
            if iv.full(self.cut_time):
                recomb_weights.append(recomb_probs[i] * iv.weight)
            else:
                recomb_weights.append(0.0)
        ws = sum(recomb_weights)
        if ws > 0:
            recomb_weights = [w / ws for w in recomb_weights]
        # Total recombination mass
        R = sum(recomb_probs[i] * prev_fp[i]
                for i in range(len(self.curr_intervals)))
        # Forward update
        new_fp = [
            prev_fp[i] * (1 - recomb_probs[i]) + R * recomb_weights[i]
            for i in range(len(self.curr_intervals))
        ]
        self.forward_probs.append(new_fp)
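To see the forward update in isolation, here is a minimal sketch on plain lists. It mirrors BSP.forward but replaces Interval objects with parallel lists; the function name bsp_forward_step and the arguments times, weights and is_full are assumptions made for this example.

```python
import math


def bsp_forward_step(prev_fp, times, weights, is_full, cut_time, rho):
    """One forward step: alpha_i * (1 - p_i) + R * w_i for every interval."""
    # p_i: recombination probability of each interval
    p = [rho * (t - cut_time) * math.exp(-rho * (t - cut_time)) for t in times]
    # Landing weights: only full intervals receive recombined mass
    raw = [p_i * w if full else 0.0
           for p_i, w, full in zip(p, weights, is_full)]
    total = sum(raw)
    w = [r / total for r in raw] if total > 0 else raw
    # R: total mass that recombines away
    R = sum(p_i * a for p_i, a in zip(p, prev_fp))
    return [a * (1.0 - p_i) + R * w_i for a, p_i, w_i in zip(prev_fp, p, w)]


prev = [0.2, 0.5, 0.3]
new = bsp_forward_step(prev, times=[1.0, 2.0, 0.5], weights=[0.3, 0.5, 0.2],
                       is_full=[True, True, False], cut_time=0.0, rho=0.1)
```

Because the landing weights sum to 1, the step only moves mass around: the total stays the same, and the partial interval (the third one) can lose mass but never gain any.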
Transfer at recombination breakpoints
When the tree topology changes, the state space itself changes — some branches disappear, others appear. The BSP must redistribute probability mass from old intervals to new ones.
Each old interval is classified by its branch:
Source branch (the lineage that recombines): mass below start_time goes to recombined_branch; mass above collapses to a point mass on merging_branch. If the interval straddles start_time, split proportionally using cc.weight().
Target branch (where it re-coalesces): mass below inserted_node.time goes to lower_transfer_branch; mass above goes to upper_transfer_branch. Some mass may also go to recombined_branch via _get_overwrite_prob().
Other branches: if affected by the topology change, route to merging_branch; otherwise keep the same branch.
After collecting all transfers, build new Interval objects. Apply a pruning cutoff: discard partial intervals whose probability is below self.cutoff. This prevents the state space from growing without bound.
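The pruning rule can be sketched as follows. This is a simplified illustration on plain lists, not the pysinger implementation: the name prune_intervals and the is_full flags are hypothetical, and renormalising after pruning is an assumption made here so the result is still a probability vector.

```python
def prune_intervals(intervals, probs, is_full, cutoff):
    """Drop partial intervals whose probability is below cutoff.

    Full intervals are always kept. The surviving probabilities are
    renormalised (an assumption of this sketch).
    """
    kept = [(iv, p) for iv, p, full in zip(intervals, probs, is_full)
            if full or p >= cutoff]
    total = sum(p for _, p in kept)
    if total <= 0:
        raise RuntimeError("all probability mass was pruned")
    return [iv for iv, _ in kept], [p / total for _, p in kept]


names, probs = prune_intervals(
    ["full_a", "partial_b", "partial_c"],
    [0.7, 0.001, 0.299],
    [True, False, False],
    cutoff=0.01,
)
```

With a cutoff of 0 (the default shown later in the Sampler) nothing is discarded, which trades speed for exactness.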
Emission
After each forward/transfer step, multiply each interval’s probability by its emission probability (from PolarEmission), then renormalise:
def null_emit(self, theta, query_node):
    fp = self.forward_probs[-1]
    for i, iv in enumerate(self.curr_intervals):
        fp[i] *= self.eh.null_emit(iv.branch, iv.time, theta, query_node)
    ws = sum(fp)
    if ws <= 0:
        raise RuntimeError("BSP null_emit: forward prob sum is zero")
    for i in range(len(fp)):
        fp[i] /= ws
Traceback — sampling the branch path
After the full forward pass, walk backward to sample which branch the lineage joins at each position:
At the rightmost position, sample an interval proportional to its forward probability.
Walk left: at each step, compute the probability of not recombining. Accumulate a running product of these “shrinkage” factors. When a uniform random draw exceeds the product, you’ve found a recombination event: sample a new source interval proportional to recomb_prob * forward_prob.
At transfer boundaries, sample from source_weights/source_intervals.
The result is a SortedDict[position → Branch] — the branch path.
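The walk can be sketched on a fixed state space, ignoring transfer boundaries. This is a simplified illustration under stated assumptions: states are list indices rather than Interval objects, and the shrinkage factor is taken to be the share of the pre-emission forward probability that came from not recombining, \(\alpha_i(1 - p_i) / (\alpha_i(1 - p_i) + R w_i)\). The function name and signature are hypothetical.

```python
import random


def sample_branch_path(forward_probs, recomb_probs, weights, rng):
    """Sample one state index per position, walking right to left.

    forward_probs[x][i] is the forward probability of state i at position x,
    recomb_probs[i] is p_i and weights[i] is the landing weight w_i.
    """
    def draw(ws):
        return rng.choices(range(len(ws)), weights=ws)[0]

    n = len(forward_probs)
    path = [None] * n
    x = n - 1
    state = draw(forward_probs[x])          # rightmost position
    path[x] = state
    u, stay = rng.random(), 1.0             # uniform draw and running product
    while x > 0:
        prev = forward_probs[x - 1]
        R = sum(p * a for p, a in zip(recomb_probs, prev))
        keep = prev[state] * (1.0 - recomb_probs[state])
        total = keep + R * weights[state]
        stay *= keep / total if total > 0 else 0.0
        x -= 1
        if u > stay:
            # Recombination found: pick the source it came from
            state = draw([p * a for p, a in zip(recomb_probs, prev)])
            u, stay = rng.random(), 1.0
        path[x] = state
    return path


fp = [[0.5, 0.5], [0.2, 0.8], [0.9, 0.1]]
path = sample_branch_path(fp, recomb_probs=[0.0, 0.0],
                          weights=[0.5, 0.5], rng=random.Random(0))
```

With all recombination probabilities set to zero the running product stays at 1, so the sampled path never switches state.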
Code:
pysinger/hmm/bsp.py
Step 10 — The TSP (Time Sequence Propagator)
Goal. Conditioned on the BSP’s sampled branch path, sample a coalescence time at each position.
Exponential-quantile grid
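For a branch \([l, u]\), the TSP discretises the time axis using the \(\text{Exp}(1)\) CDF, \(F(t) = 1 - e^{-t}\). Instead of uniformly spacing grid points, it places them at evenly spaced exponential quantiles. With gap \(q\), the number of cells is:
\[n = \left\lceil \frac{F(u) - F(l)}{q} \right\rceil\]
Grid boundaries are at:
\[t_i = -\ln\!\left(1 - F(l) - \frac{i}{n}\,\bigl(F(u) - F(l)\bigr)\right), \quad i = 0, 1, \ldots, n\]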
Why quantile spacing? The coalescent density is highest near the bottom of the branch (close to the present). Uniform spacing would waste most grid points in the sparse upper region. Quantile spacing automatically concentrates resolution where it matters — near the bottom. The gap parameter \(q\) (default 0.02) controls how many grid points you get: smaller \(q\) = finer grid.
def _generate_grid(self, lb, ub):
    lq = 1.0 - math.exp(-lb)
    uq = 1.0 - math.exp(-ub)
    q = uq - lq
    n = math.ceil(q / self.gap)
    points = [lb]
    for i in range(1, n):
        p = lq + i * q / n
        points.append(-math.log(1.0 - p))
    points.append(ub)
    return points

def _generate_intervals(self, branch, lb, ub):
    """Create Interval objects for each adjacent pair of grid points."""
    points = self._generate_grid(lb, ub)
    for i in range(len(points) - 1):
        iv = Interval(branch, points[i], points[i + 1], self.curr_index)
        iv.fill_time()
        self.curr_intervals.append(iv)
        self._temp.append(0.0)
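A small worked example makes the spacing visible. The function below is a standalone copy of the grid logic (the name exp_quantile_grid is made up for this sketch), run with a deliberately coarse gap of 0.25 so the output is easy to read.

```python
import math


def exp_quantile_grid(lb, ub, gap=0.02):
    """Grid points between lb and ub at evenly spaced Exp(1) quantiles."""
    lq = 1.0 - math.exp(-lb)
    uq = 1.0 - math.exp(-ub)
    n = math.ceil((uq - lq) / gap)
    points = [lb]
    for i in range(1, n):
        points.append(-math.log(1.0 - (lq + i * (uq - lq) / n)))
    points.append(ub)
    return points


# Branch from time 0 up to the root sentinel (infinite upper bound)
grid = exp_quantile_grid(0.0, math.inf, gap=0.25)
widths = [b - a for a, b in zip(grid, grid[1:])]
```

The interior boundaries fall at the 25%, 50% and 75% quantiles of Exp(1), so the middle point is the median \(\ln 2\), and the cell widths grow with time: the grid is finest near the bottom of the branch.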
The PSMC transition kernel
The TSP uses a PSMC-style transition model. Given a lineage currently at time \(s\), the probability of transitioning to interval \([t_1, t_2]\) is:
The first term is a point mass: with probability \(e^{-\rho\ell}\) (where \(\ell = 2s - \text{lower\_bound} - \text{cut\_time}\)), no recombination happens and the lineage stays at time \(s\). The integral term accounts for recombination followed by re-coalescence at a new time.
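The transition kernel is stored in a tridiagonal-style layout: a diagonal array plus one lower and one upper array. Most of the probability mass stays at the current time or moves to adjacent cells, and transitions to more distant cells follow from the same three arrays through running sums. This allows an \(O(K)\) forward update instead of \(O(K^2)\).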
The key idea behind the \(O(K)\) recursion: instead of computing all \(K \times K\) transition probabilities, use cumulative sums from below (lower_sums) and from above (upper_sums):
def forward(self, rho):
    """Advance one bin using tridiagonal PSMC kernel."""
    self._compute_diagonals(rho)        # stay-in-place: D[i]
    self._compute_lower_diagonals(rho)  # from interval above: L[i]
    self._compute_upper_diagonals(rho)  # from interval below: U[i]
    # lower_sums[i] accumulates contributions from all intervals below i
    self.lower_sums[0] = 0.0
    for i in range(1, self.dim):
        self.lower_sums[i] = (
            self.upper_diagonals[i] * self.forward_probs[-1][i - 1]
            + self.factors[i] * self.lower_sums[i - 1]
        )
    # upper_sums[i] = sum of forward probs above i
    self.upper_sums[-1] = 0.0
    for i in range(self.dim - 2, -1, -1):
        self.upper_sums[i] = (
            self.forward_probs[-1][i + 1] + self.upper_sums[i + 1]
        )
    # The full forward update in O(K)
    new_fp = []
    for i in range(self.dim):
        new_fp.append(
            self.lower_sums[i]
            + self.diagonals[i] * self.forward_probs[-1][i]
            + self.lower_diagonals[i] * self.upper_sums[i]
        )
    self.forward_probs.append(new_fp)
The factors[i] array stores ratios of exponential masses between adjacent intervals: \((e^{-lb_i} - e^{-ub_i}) / (e^{-lb_{i-1}} - e^{-ub_{i-1}})\), clamped to at most 5 to prevent numerical instability.
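One way to convince yourself that the running sums are correct is to compare them with the dense matrix they imply. The sketch below does that on plain lists; the array names D, L, U and f mirror diagonals, lower_diagonals, upper_diagonals and factors, and the numbers are arbitrary test values rather than real PSMC quantities.

```python
import random


def forward_fast(a, D, L, U, f):
    """O(K) update using running sums, as in TSP.forward."""
    K = len(a)
    lower = [0.0] * K
    for i in range(1, K):
        lower[i] = U[i] * a[i - 1] + f[i] * lower[i - 1]
    upper = [0.0] * K
    for i in range(K - 2, -1, -1):
        upper[i] = a[i + 1] + upper[i + 1]
    return [lower[i] + D[i] * a[i] + L[i] * upper[i] for i in range(K)]


def forward_dense(a, D, L, U, f):
    """O(K^2) reference: the same update written entry by entry."""
    K = len(a)
    out = []
    for i in range(K):
        total = D[i] * a[i]
        for j in range(i + 1, K):           # mass arriving from above
            total += L[i] * a[j]
        for j in range(i):                  # mass arriving from below
            coeff = U[j + 1]
            for k in range(j + 2, i + 1):
                coeff *= f[k]
            total += coeff * a[j]
        out.append(total)
    return out


rng = random.Random(1)
K = 6
a, D, L, U, f = ([rng.random() for _ in range(K)] for _ in range(5))
fast = forward_fast(a, D, L, U, f)
dense = forward_dense(a, D, L, U, f)
```

The two results agree up to rounding, which shows that the entry from cell \(j\) below to cell \(i\) is the adjacent term U[j+1] scaled by the product of factors between them.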
Transfer at topology changes
Three cases:
Source → merging: The lineage was on the source branch; after the recombination, it’s on the merging branch. Collapse all probability to a point mass at deleted_node.time. You know exactly when the coalescence happened because the deleted node is at a known time.
Target → recombined: The lineage was on the target branch; now it’s on the newly created recombined branch. Expand from a point mass to the full recombined branch range, putting probability above start_time.
Regular: Overlap intervals by time between the old and new branches, transferring mass proportionally (in exponential measure).
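The regular case can be sketched as follows, assuming “exponential measure” means the \(\text{Exp}(1)\) mass \(e^{-a} - e^{-b}\) of a time range \([a, b]\). Intervals are plain tuples here and the function names are hypothetical.

```python
import math


def exp_mass(lb, ub):
    """Exp(1) probability mass of the time range [lb, ub]."""
    return math.exp(-lb) - math.exp(-ub)


def transfer_by_overlap(old, new_bounds):
    """Move each old interval's probability onto the new cells it overlaps.

    old is a list of (lb, ub, prob); new_bounds is a list of (lb, ub).
    Each old interval splits its mass in proportion to the exponential
    measure of the overlap.
    """
    out = [0.0] * len(new_bounds)
    for lb, ub, prob in old:
        total = exp_mass(lb, ub)
        if total <= 0:
            continue
        for k, (nl, nu) in enumerate(new_bounds):
            lo, hi = max(lb, nl), min(ub, nu)
            if hi > lo:
                out[k] += prob * exp_mass(lo, hi) / total
    return out


old = [(0.0, 1.0, 0.6), (1.0, math.inf, 0.4)]
new_bounds = [(0.0, 0.5), (0.5, 2.0), (2.0, math.inf)]
new_probs = transfer_by_overlap(old, new_bounds)
```

As long as the new cells cover the old time range, the total probability is unchanged by the transfer.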
Traceback — sampling coalescence times
Walk backward, sampling an interval at each step. Convert each sampled interval to a Node at a jittered time:
def _exp_median(self, lb, ub):
    """Sample a time near the median of [lb, ub] in exponential quantile space."""
    if math.isinf(ub):
        return lb + 2.0 * random()
    if ub - lb <= 0.005:
        # Very short interval: jitter around the midpoint in plain time
        return (0.45 + 0.1 * random()) * (ub - lb) + lb
    lq = 1.0 - math.exp(-lb)
    uq = 1.0 - math.exp(-ub)
    # Jitter around the median: 0.45 to 0.55 in quantile space
    mq = (0.45 + 0.1 * random()) * (uq - lq) + lq
    return max(lb, min(ub, -math.log(1.0 - mq)))
The jitter (0.45 to 0.55 instead of exactly 0.5) ensures diversity across MCMC iterations — without it, you’d always pick the same time for a given interval.
Code:
pysinger/hmm/tsp.py
Step 11 — The Threader
Goal. Combine BSP + TSP into a single threading operation.
The Threader orchestrates the full pipeline: run the BSP to choose branches, run the TSP to choose times, then insert the lineage into the ARG.
Initial threading
class Threader:
    def __init__(self, cutoff=0.0, gap=0.02):
        self.bsp = BSP()
        self.tsp = TSP()
        self.pe = PolarEmission()   # BSP uses polarised emission
        self.be = BinaryEmission()  # TSP uses symmetric emission

    def thread(self, arg, node):
        """Add a new leaf node to the ARG."""
        arg.add_sample(node)
        self._run_bsp(arg)                  # BSP forward pass
        self._sample_joining_branches(arg)  # BSP traceback → branches
        self._run_tsp(arg)                  # TSP forward pass
        self._sample_joining_points(arg)    # TSP traceback → nodes
        arg.add(self.new_joining_branches,
                self.added_branches)        # thread into ARG
        arg.approx_sample_recombinations()  # set recombination times
        arg.clear_remove_info()
The BSP forward pass loop
def _run_bsp(self, arg):
    self.bsp.set_emission(self.pe)  # PolarEmission for branch selection
    self.bsp.start(arg.start_tree, self.cut_time)
    for i in range(self.start_index, self.end_index):
        pos = arg.coordinates[i]
        query_node = arg.get_query_node_at(pos)
        # Transition step: topology change or regular forward.
        # Assumes arg.recombinations maps breakpoint position -> Recombination;
        # no transition is needed at the start position itself.
        if pos != self.start:
            recombination = arg.recombinations.get(pos)
            if recombination is not None:
                self.bsp.transfer(recombination)
            else:
                self.bsp.forward(arg.rhos[i - 1])
        # Collect mutations in [pos, next_pos)
        next_pos = arg.coordinates[i + 1]
        mut_set = {m for m in arg.mutation_sites
                   if pos <= m < next_pos}
        # Emission step
        if mut_set:
            self.bsp.mut_emit(arg.thetas[i], next_pos - pos,
                              mut_set, query_node)
        else:
            self.bsp.null_emit(arg.thetas[i], query_node)
The TSP forward pass loop
Same structure, but conditioned on the BSP’s sampled joining branches. The TSP uses BinaryEmission and handles three transition types: transfer (at recombination), recombine (branch change without topology change), and forward (same branch).
MCMC re-threading with Metropolis acceptance
def internal_rethread(self, arg, cut_point):
    """MCMC move: remove a lineage and propose a new threading."""
    self.cut_time = cut_point[2]
    arg.remove(cut_point)               # extract lineage
    self._run_bsp(arg)                  # propose new branches
    self._sample_joining_branches(arg)
    self._run_tsp(arg)                  # propose new times
    self._sample_joining_points(arg)
    # Metropolis acceptance ratio
    ar = self._acceptance_ratio(arg)
    if random() < ar:
        arg.add(self.new_joining_branches,
                self.added_branches)    # ACCEPT proposal
    else:
        arg.add(arg.joining_branches,
                arg.removed_branches)   # REJECT: restore original
    arg.approx_sample_recombinations()
    arg.clear_remove_info()
The acceptance ratio
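A proposal is accepted with probability
\[A = \min\!\left(1, \frac{h_{\text{old}}}{h_{\text{new}}}\right)\]
where \(h\) is the effective tree height at the cut position: the maximum child-node time in the cut tree. If the joining branch reaches the root sentinel, use the coalescence node’s time instead.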
This ratio favours proposals that produce comparable or shorter trees. Proposals that inflate the tree height get penalised (ratio < 1, lower acceptance probability). This keeps the MCMC from drifting toward unreasonably tall trees.
def _acceptance_ratio(self, arg):
    cut_height = max(child.time
                     for child in arg.cut_tree.parents.keys())
    old_height = cut_height
    new_height = cut_height
    # old_* / new_* stand for the joining and added branches at the cut
    # position before and after the proposal (their lookup is omitted here).
    # Adjust if joining branch reaches root
    if old_joining_branch.upper_node is arg.root:
        old_height = old_added_branch.upper_node.time
    if new_joining_branch.upper_node is arg.root:
        new_height = new_added_branch.upper_node.time
    if new_height <= 0:
        return 1.0
    return old_height / new_height
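The accept/reject decision itself needs only the two heights. Here is a minimal standalone sketch; the function names are made up for illustration and the heights are passed in directly instead of being read from the ARG.

```python
import random


def acceptance_ratio(old_height, new_height):
    """Height ratio used for the Metropolis decision."""
    if new_height <= 0:
        return 1.0
    return old_height / new_height


def accept(old_height, new_height, rng):
    """Accept with probability min(1, old_height / new_height)."""
    return rng.random() < acceptance_ratio(old_height, new_height)


rng = random.Random(42)
# A proposal that doubles the height is accepted about half the time
n_accepted = sum(accept(2.0, 4.0, rng) for _ in range(10_000))
# A proposal that shrinks the height is always accepted
always = all(accept(4.0, 2.0, rng) for _ in range(1_000))
```

Comparing a uniform draw against the raw ratio already implements the min with 1, because the draw is always below 1.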
Code:
pysinger/mcmc/threader.py
Step 12 — The Sampler
Goal. Top-level MCMC orchestrator.
class Sampler:
    def __init__(self, Ne=1.0, recomb_rate=0.0, mut_rate=0.0):
        self.Ne = Ne
        self.recomb_rate = recomb_rate * Ne  # scale to coalescent units
        self.mut_rate = mut_rate * Ne
        self.bsp_c = 0.0   # BSP pruning cutoff
        self.tsp_q = 0.02  # TSP grid gap
Loading data
def load_vcf(self, vcf_file, start=0, end=float("inf"), haploid=False):
    if haploid:
        nodes, seq_len = read_vcf_haploid(vcf_file, start, end)
    else:
        nodes, seq_len = read_vcf_phased(vcf_file, start, end)
    self.sequence_length = seq_len
    self.sample_nodes = set(nodes)
    # Shuffle threading order for better initial ARGs
    self.ordered_sample_nodes = shuffled(nodes, seed=self._seed)
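The shuffled helper is not defined in this chapter. One possible implementation, assuming all that is needed is a reproducible permutation that leaves the input list untouched, uses the standard library:

```python
import random


def shuffled(items, seed=None):
    """Return a new list with the items in a seeded random order."""
    out = list(items)
    random.Random(seed).shuffle(out)
    return out


samples = ["hap0", "hap1", "hap2", "hap3", "hap4"]
order_a = shuffled(samples, seed=7)
order_b = shuffled(samples, seed=7)
```

Using a private Random instance keeps the threading order reproducible for a given seed without disturbing the global random state.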
Building the initial ARG
def iterative_start(self, max_retries=5):
    """Thread all samples one-by-one to build an initial ARG."""
    for attempt in range(max_retries):
        try:
            self._build_singleton_arg()              # ARG with first sample
            for node in self.ordered_sample_nodes[1:]:
                threader = self._make_threader()
                threader.thread(self.arg, node)      # thread each sample
            self._rescale()                          # calibrate times
            return
        except RuntimeError:
            # HMM underflow: retry with different RNG
            self._rng = np.random.default_rng(self._seed + attempt + 1)
The MCMC loop
def internal_sample(self, num_iters, spacing=1):
    """Run MCMC iterations. Each iteration re-threads enough lineages
    to cover at least spacing * sequence_length base pairs."""
    for iteration in range(num_iters):
        updated_length = 0.0
        while updated_length < spacing * self.arg.sequence_length:
            threader = self._make_threader()
            cut_point = self.arg.sample_internal_cut()
            try:
                threader.internal_rethread(self.arg, cut_point)
            except Exception:
                # Restore ARG on failure
                if self.arg.joining_branches:
                    self.arg.add(self.arg.joining_branches,
                                 self.arg.removed_branches)
                    self.arg.approx_sample_recombinations()
                self.arg.clear_remove_info()
                break
            updated_length += (threader.end - threader.start)
            self.arg.clear_remove_info()
        self._rescale()
Rescaling — mutation-rate calibration
After each MCMC iteration, all coalescence times are rescaled so the expected number of mutations matches the observed count. The idea: if the ARG’s total branch length predicts more mutations than we observe, the tree is too tall — scale it down. If it predicts fewer, scale it up.
\[\text{scale} = \frac{S_{\text{obs}}}{\mu_{\text{scaled}} \cdot L_{\text{total}}}\]
where:
\(S_{\text{obs}}\) = number of unique segregating site positions across all samples
\(\mu_{\text{scaled}} = \mu \cdot N_e\) = mutation rate in coalescent units
\(L_{\text{total}}\) = total ARG branch length from get_arg_length()
def _rescale(self):
    total_obs = len({pos for n in self.sample_nodes
                     for pos in n.mutation_sites if pos >= 0})
    total_branch = self.arg.get_arg_length()
    expected = self.mut_rate * total_branch
    if expected <= 0:
        return 1.0
    scale = total_obs / expected
    # Discover internal nodes by walking all trees
    internal_nodes = []
    tree = self.arg.get_tree_at(0.0)
    sample_ids = {id(n) for n in self.sample_nodes}
    seen = set()
    for n in tree.parents.values():
        if n.index >= 0 and id(n) not in seen and id(n) not in sample_ids:
            seen.add(id(n))
            internal_nodes.append(n)
    for pos, r in self.arg.recombinations.items():
        if 0 < pos < self.arg.sequence_length:
            tree.forward_update(r)
            for n in tree.parents.values():
                if n.index >= 0 and id(n) not in seen and id(n) not in sample_ids:
                    seen.add(id(n))
                    internal_nodes.append(n)
    # Rescale all internal node times
    for n in internal_nodes:
        n.time *= scale
    # Also rescale recombination start_times
    for pos, r in self.arg.recombinations.items():
        if 0 < pos < self.arg.sequence_length and r.start_time > 0:
            r.start_time *= scale
    return scale
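A quick arithmetic check of the scale factor, with made-up numbers: 1,200 observed segregating sites, \(\mu = 10^{-8}\), \(N_e = 10{,}000\), and a total ARG length of \(1.5 \times 10^7\). The helper name rescale_factor is hypothetical; it only isolates the first few lines of _rescale.

```python
def rescale_factor(n_segregating, mut_rate, Ne, total_branch_length):
    """Observed mutation count divided by the count the ARG predicts."""
    expected = mut_rate * Ne * total_branch_length
    if expected <= 0:
        return 1.0
    return n_segregating / expected


# Expected mutations: 1e-8 * 1e4 * 1.5e7 = 1500, so the ARG is too tall
scale = rescale_factor(1200, 1e-8, 10_000, 1.5e7)
```

Here the scale comes out at about 0.8, so every internal node time (and every recombination start_time) would shrink by 20%.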
Code:
pysinger/sampler.py
Step 13 — VCF reader
Goal. Parse phased genotypes into Node objects.
def read_vcf_phased(vcf_file, start_pos=0.0, end_pos=float("inf")):
    """Read a phased VCF and return (nodes, sequence_length).
    Creates 2 Node objects per diploid individual (one per haplotype).
    """
    nodes = []
    with open(vcf_file) as fh:
        for line in fh:
            if line.startswith("#CHROM"):
                # Count individuals from header
                fields = line.split()
                n_ind = len(fields) - 9
                nodes = [Node(time=0.0, index=i)
                         for i in range(2 * n_ind)]
                continue
            if line.startswith("#"):
                continue
            parts = line.split()
            pos = float(parts[1])
            ref, alt = parts[3], parts[4]
            # Skip if outside region, indel, or multi-allelic
            if pos < start_pos or pos > end_pos:
                continue
            if len(ref) > 1 or len(alt) > 1:
                continue
            # Parse genotypes: "0|1" -> alleles [0, 1]
            genotypes = parts[9:]
            gt_vals = []
            for g in genotypes:
                gt_vals.append(1 if g[0] == "1" else 0)
                gt_vals.append(1 if (len(g) > 2 and g[2] == "1") else 0)
            # Only keep segregating sites (not fixed ancestral or derived)
            gt_sum = sum(gt_vals)
            if 1 <= gt_sum < len(gt_vals):
                rel_pos = pos - start_pos
                for i, v in enumerate(gt_vals):
                    if v == 1:
                        nodes[i].add_mutation(rel_pos)
    # Note: with the default end_pos=inf this length is infinite,
    # so callers should pass a finite end_pos.
    sequence_length = end_pos - start_pos
    return nodes, sequence_length
Key points:
Positions are stored relative to start_pos so the ARG always starts at coordinate 0.
Fixed sites (all ancestral or all derived) are skipped: they carry no information about the genealogy.
Indels and multi-allelic sites are skipped: SINGER assumes biallelic SNPs.
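The genotype handling can be tried on its own without a file. The sketch below is a simplified variant of the parsing step, assuming strictly phased diploid fields separated by "|"; the helper names parse_phased_gt and is_segregating are invented for this example.

```python
def parse_phased_gt(field):
    """Return (allele0, allele1) for a phased GT field like '0|1' or '1|0:35'.

    Anything other than '1' is treated as the ancestral allele, as in the
    reader above.
    """
    gt = field.split(":")[0]       # drop extra FORMAT values after the GT
    left, right = gt.split("|")    # assumes a phased diploid genotype
    return int(left == "1"), int(right == "1")


def is_segregating(alleles):
    """True when the site is neither fixed ancestral nor fixed derived."""
    return 0 < sum(alleles) < len(alleles)


record = "1\t500\t.\tA\tG\t.\tPASS\t.\tGT\t0|1\t0|0\t1|0:35"
fields = record.split("\t")
alleles = [a for g in fields[9:] for a in parse_phased_gt(g)]
```

For this record the six haplotypes carry alleles 0, 1, 0, 0, 1, 0, so the site is segregating and haplotypes 1 and 4 would each get a mutation at position 500.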
Code:
pysinger/io/vcf_reader.py
Step 14 — Fitch parsimony reconstruction
Goal. Assign ancestral states to internal nodes with the minimum number of mutations.
When the ARG threads a new lineage, the newly created coalescence nodes need allele states. Fitch parsimony does this optimally (minimum mutations) in two passes:
Pass 1 — Pruning (bottom-up)
Start at the leaves and work up. Each leaf reads its state from node.get_state(pos). For each internal node, merge its two children:
def _fitch_up(self, c1, c2, parent):
    s1 = self.pruning_states[c1]
    s2 = self.pruning_states[c2]
    if s1 == s2:
        self.pruning_states[parent] = s1   # children agree
    elif s1 == 0.5:
        self.pruning_states[parent] = s2   # c1 is ambiguous
    elif s2 == 0.5:
        self.pruning_states[parent] = s1   # c2 is ambiguous
    else:
        self.pruning_states[parent] = 0.5  # disagree: ambiguous
The value 0.5 means “ambiguous — could be either 0 or 1.” When children disagree, a mutation must have occurred on one of the branches, but we don’t know which — so we defer the decision to the top-down pass.
Pass 2 — Peeling (top-down)
Start at the root and work down, resolving ambiguities using the parent’s definite state:
def _fitch_down(self, parent, child):
    if parent.index == -1:
        # Root sentinel: resolve ambiguity to ancestral (0)
        s = 0.0 if self.pruning_states[child] == 0.5 \
            else self.pruning_states[child]
        self.peeling_states[child] = s
        return
    sp = self.peeling_states[parent]  # parent's resolved state
    sc = self.pruning_states[child]   # child's pruning state
    if sc == 0.5:
        self.peeling_states[child] = sp  # inherit parent's state
    else:
        self.peeling_states[child] = sc  # keep definite state
After both passes, write the resolved states back: node.write_state(pos, state) for every internal node.
def reconstruct(self, pos):
    self.pruning_states = {}
    self.peeling_states = {}
    for n in self.node_set:
        self._pruning_pass(n)      # bottom-up
    for n in self.node_set:
        self._peeling_pass(n)      # top-down
    for n, s in self.peeling_states.items():
        n.write_state(pos, s)      # write back
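To watch both passes work end to end, here is a self-contained sketch on a four-leaf tree. It uses plain dicts keyed by node name instead of Node objects, and the top node of the dict plays the role of the child of the root sentinel; the function name fitch and its signature are assumptions of this example.

```python
def fitch(children, root, leaf_states):
    """Two-pass reconstruction; returns (resolved states, mutation count).

    children maps each internal node to its two children.
    """
    pruning = {}

    def prune(node):                       # pass 1: bottom-up
        if node not in children:
            pruning[node] = leaf_states[node]
            return
        c1, c2 = children[node]
        prune(c1)
        prune(c2)
        s1, s2 = pruning[c1], pruning[c2]
        if s1 == s2:
            pruning[node] = s1
        elif s1 == 0.5:
            pruning[node] = s2
        elif s2 == 0.5:
            pruning[node] = s1
        else:
            pruning[node] = 0.5

    def peel(node, peeling):               # pass 2: top-down
        for child in children.get(node, ()):
            amb = pruning[child] == 0.5
            peeling[child] = peeling[node] if amb else pruning[child]
            peel(child, peeling)

    prune(root)
    # An ambiguous top node resolves to ancestral (0)
    peeling = {root: 0.0 if pruning[root] == 0.5 else pruning[root]}
    peel(root, peeling)
    n_mut = sum(peeling[c] != peeling[p]
                for p, kids in children.items() for c in kids)
    return peeling, n_mut


children = {"r": ("x", "y"), "x": ("a", "b"), "y": ("c", "d")}
leaves = {"a": 1.0, "b": 1.0, "c": 0.0, "d": 0.0}
states, n_mut = fitch(children, "r", leaves)
```

In this example x resolves to derived, y and r to ancestral, and a single mutation on the branch from r to x explains all four leaves.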
Code:
pysinger/reconstruction/fitch.py
Step 15 — tskit export
Goal. Convert the final ARG to a tskit.TreeSequence for downstream analysis.
import tskit

def arg_to_tskit(arg, Ne=1.0):
    tables = tskit.TableCollection(sequence_length=arg.sequence_length)
    # 1. Discover all nodes by replaying the full ARG
    node_map = {}  # id(pysinger_node) -> tskit_node_id
    tree = arg.get_tree_at(0.0)

    def collect(tree_obj):
        for child, parent in tree_obj.parents.items():
            for n in (child, parent):
                if n.index != -1 and id(n) not in node_map:
                    is_sample = n in arg.sample_nodes
                    tskit_id = tables.nodes.add_row(
                        flags=tskit.NODE_IS_SAMPLE if is_sample else 0,
                        time=n.time * Ne,
                    )
                    node_map[id(n)] = tskit_id

    collect(tree)
    for pos, r in arg.recombinations.items():
        if 0 < pos < arg.sequence_length:
            tree.forward_update(r)
            collect(tree)
    # 2. Emit edges for each tree interval
    tree = arg.get_tree_at(0.0)
    prev_pos = 0.0
    recomb_positions = sorted(
        p for p in arg.recombinations if 0 < p < arg.sequence_length
    )

    def emit_edges(tree_obj, left, right):
        for child, parent in tree_obj.parents.items():
            if parent.index == -1 or parent.time <= child.time:
                continue
            c_id = node_map.get(id(child))
            p_id = node_map.get(id(parent))
            if c_id is not None and p_id is not None:
                tables.edges.add_row(
                    left=left, right=right, parent=p_id, child=c_id
                )

    for rpos in recomb_positions:
        emit_edges(tree, prev_pos, rpos)
        tree.forward_update(arg.recombinations[rpos])
        prev_pos = rpos
    emit_edges(tree, prev_pos, arg.sequence_length)
    # 3. Sort and build
    tables.sort()
    return tables.tree_sequence()
The time conversion n.time * Ne converts from coalescent units (where 1 unit = \(N_e\) generations) back to generations, which is what tskit expects. The resulting tree sequence can be used with the full tskit API for computing diversity, TMRCA, drawing trees, etc.
Code:
pysinger/io/tskit_writer.py
Putting it all together
from pysinger import Sampler
from pysinger.io.tskit_writer import arg_to_tskit
sampler = Sampler(Ne=10000, recomb_rate=1e-8, mut_rate=1e-8)
sampler.load_vcf("data.vcf", start=0, end=1_000_000)
sampler.iterative_start()
sampler.internal_sample(num_iters=100, spacing=1)
ts = arg_to_tskit(sampler.arg, Ne=10000)
Dependency graph
Steps 1-2 (Node, Branch)
└── Step 3 (Tree)
    └── Step 4 (Recombination)
        └── Step 5 (Interval)
            └── Step 6 (ARG)
                └── Step 7 (CoalescentCalculator)
                    └── Step 8 (Emission models)
                        └── Step 9 (BSP)
                            └── Step 10 (TSP)
                                └── Step 11 (Threader)
                                    └── Step 12 (Sampler)
Step 13 (VCF reader) ← independent, needed by Sampler
Step 14 (Fitch) ← used by ARG._impute()
Step 15 (tskit writer) ← used after sampling