# Source code for pysinger.data.tree

"""
Tree — a marginal coalescent tree at a single genomic position.

Mirrors Tree.cpp / Tree.hpp.

The Tree is represented by two dicts:
  parents:  child_node  -> parent_node
  children: parent_node -> set of child_nodes

The topology is updated forward/backward along the genome by applying
Recombination records.
"""
from __future__ import annotations

import copy
from typing import Dict, Optional, Set, TYPE_CHECKING

if TYPE_CHECKING:
    from .branch import Branch
    from .node import Node
    from .recombination import Recombination


class Tree:
    """Marginal coalescent tree at one genomic position.

    The topology is stored redundantly in two dicts that every mutating
    method keeps in sync.

    Attributes
    ----------
    parents:
        child -> parent mapping (all non-root nodes present as keys).
    children:
        parent -> set of children (only internal nodes as keys).
    """

    def __init__(self) -> None:
        self.parents: Dict["Node", "Node"] = {}
        self.children: Dict["Node", Set["Node"]] = {}

    # ------------------------------------------------------------------
    # Node-level primitives (shared by the branch operations below)
    # ------------------------------------------------------------------

    def _link(self, child: "Node", parent: "Node") -> None:
        """Record the edge child -> parent in both dicts."""
        assert child is not None and parent is not None
        self.parents[child] = parent
        self.children.setdefault(parent, set()).add(child)

    def _unlink(self, child: "Node", parent: "Node") -> None:
        """Drop the edge child -> parent from both dicts.

        The ``parents`` entry of *child* is removed unconditionally, i.e.
        even when it currently points at a node other than *parent*.
        Callers therefore have to unlink an old edge *before* linking its
        replacement for the same child.
        """
        assert child is not None and parent is not None
        self.parents.pop(child, None)
        kids = self.children.get(parent)
        if kids is None:
            return
        kids.discard(child)
        if not kids:
            # Only nodes that still have children stay as keys.
            del self.children[parent]

    # ------------------------------------------------------------------
    # Basic branch operations
    # ------------------------------------------------------------------

    def insert_branch(self, branch: "Branch") -> None:
        """Add *branch* (lower_node -> upper_node) to the tree.

        An existing parent entry for ``branch.lower_node`` is overwritten.
        Raises AssertionError if either end of the branch is None.
        """
        self._link(branch.lower_node, branch.upper_node)

    def delete_branch(self, branch: "Branch") -> None:
        """Remove *branch* from the tree.

        Deleting a branch that is not present is a no-op for ``children``;
        the ``parents`` entry of ``branch.lower_node`` is removed in any
        case.  Raises AssertionError if either end of the branch is None.
        """
        self._unlink(branch.lower_node, branch.upper_node)

    # ------------------------------------------------------------------
    # Topology updates
    # ------------------------------------------------------------------

    def forward_update(self, r: "Recombination") -> None:
        """Apply recombination *r* moving right along the genome."""
        for old in r.deleted_branches:
            self.delete_branch(old)
        for new in r.inserted_branches:
            self.insert_branch(new)

    def backward_update(self, r: "Recombination") -> None:
        """Undo recombination *r* moving left along the genome."""
        for new in r.inserted_branches:
            self.delete_branch(new)
        for old in r.deleted_branches:
            self.insert_branch(old)

    def internal_forward_update(self, r: "Recombination", cut_time: float) -> None:
        """forward_update restricted to branches with upper_node.time > cut_time."""
        for old in r.deleted_branches:
            if old.upper_node.time > cut_time:
                self.delete_branch(old)
        for new in r.inserted_branches:
            if new.upper_node.time > cut_time:
                self.insert_branch(new)

    def internal_backward_update(self, r: "Recombination", cut_time: float) -> None:
        """backward_update restricted to branches with upper_node.time > cut_time."""
        for new in r.inserted_branches:
            if new.upper_node.time > cut_time:
                self.delete_branch(new)
        for old in r.deleted_branches:
            if old.upper_node.time > cut_time:
                self.insert_branch(old)

    # ------------------------------------------------------------------
    # Structural queries
    # ------------------------------------------------------------------

    def find_sibling(self, n: "Node") -> Optional["Node"]:
        """Return a child of n's parent that is not *n* itself.

        Returns None when *n* is the only child.  Raises KeyError when *n*
        has no parent recorded in the tree.
        """
        for candidate in self.children[self.parents[n]]:
            if candidate is not n:
                return candidate
        return None  # shouldn't happen in a binary tree

    def find_joining_branch(self, removed_branch: "Branch") -> "Branch":
        """Return Branch(sibling, grandparent) for the removed branch.

        Mirrors Tree::find_joining_branch.

        A null Branch is returned when *removed_branch* is null, when its
        lower node has no parent in the tree, or when no sibling exists.
        When ``removed_branch.upper_node`` has no parent of its own, that
        node itself is used as the upper end of the result.
        """
        from .branch import Branch

        if removed_branch.is_null():
            return Branch()
        lower, upper = removed_branch.lower_node, removed_branch.upper_node
        if lower not in self.parents:
            return Branch()
        sibling = self.find_sibling(lower)
        if sibling is None:
            return Branch()
        # Fall back to `upper` itself when it has no parent entry.
        return Branch(sibling, self.parents.get(upper, upper))

    # ------------------------------------------------------------------
    # MCMC remove / add (used to update cut_tree in ARG)
    # ------------------------------------------------------------------

    def remove(self, branch: "Branch", cut_node: "Node") -> None:
        """Remove *branch* from the tree and replace with cut lineage.

        Mirrors Tree::remove.

        After this call ``branch.lower_node`` hangs under *cut_node*,
        ``branch.upper_node`` is no longer part of the tree, and the
        sibling is attached via the branch returned by
        ``find_joining_branch``.

        Raises AssertionError if ``branch.upper_node.index`` is negative
        and KeyError if ``branch.upper_node`` has no parent in the tree.
        """
        lower, upper = branch.lower_node, branch.upper_node
        assert upper.index >= 0, "Cannot remove branch to root sentinel"
        # Collect everything that depends on the current topology before
        # any edge is touched.
        joining_branch = self.find_joining_branch(branch)
        sibling = self.find_sibling(lower)
        grandparent = self.parents[upper]
        self.delete_branch(branch)
        self._unlink(sibling, upper)
        self._unlink(upper, grandparent)
        self.insert_branch(joining_branch)
        self._link(lower, cut_node)

    def add(
        self,
        added_branch: "Branch",
        joining_branch: "Branch",
        cut_node: Optional["Node"],
    ) -> None:
        """Insert *added_branch* joining at *joining_branch*.

        Mirrors Tree::add.

        *joining_branch* is split at ``added_branch.upper_node`` and
        *added_branch* is attached there.  When *cut_node* is given, the
        edge ``added_branch.lower_node -> cut_node`` is removed as well.
        """
        new_parent = added_branch.upper_node
        self.delete_branch(joining_branch)
        self._link(joining_branch.lower_node, new_parent)
        self._link(new_parent, joining_branch.upper_node)
        if cut_node is not None:
            # Detach from the cut node *before* attaching the new branch:
            # unlinking always drops the child's `parents` entry, so doing
            # it afterwards would erase the parent that was just recorded.
            self._unlink(added_branch.lower_node, cut_node)
        self.insert_branch(added_branch)

    # ------------------------------------------------------------------
    # Misc helpers
    # ------------------------------------------------------------------

    def length(self) -> float:
        """Total branch length (excluding root sentinel branches).

        Branches whose parent has ``index == -1`` are skipped.
        """
        total = 0.0
        for child, parent in self.parents.items():
            if parent.index != -1:
                total += parent.time - child.time
        return total

    def copy(self) -> "Tree":
        """Shallow-copy the topology dicts (nodes are shared)."""
        clone = Tree()
        clone.parents = dict(self.parents)
        clone.children = {
            parent: set(kids) for parent, kids in self.children.items()
        }
        return clone

    def get_branches(self) -> Set["Branch"]:
        """Return all branches in the tree as a set of Branch objects."""
        from .branch import Branch

        return {Branch(child, parent) for child, parent in self.parents.items()}

    def __iter__(self):
        """Allow passing a Tree directly where a set of branches is expected."""
        from .branch import Branch

        for child, parent in self.parents.items():
            yield Branch(child, parent)

    def __repr__(self) -> str:
        return f"Tree(n_branches={len(self.parents)})"