# Source code for pysinger.io.tskit_writer

"""
tskit writer — convert a pysinger ARG to a tskit TreeSequence.

Mirrors the conceptual output of SINGER's write() method, but produces
a tskit.TreeSequence that can be analysed with the tskit/msprime ecosystem.

The conversion:
  1. Add all non-root nodes (time = node.time * Ne) to the tskit node table.
  2. Walk the ARG recombinations to build a list of (left, right, parent, child)
     edges spanning each genomic interval.
  3. Call tables.sort() (tskit requires sorted edges) and tables.tree_sequence().
"""
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from ..data.arg import ARG

try:
    import tskit
    _HAS_TSKIT = True
except ImportError:
    _HAS_TSKIT = False


def arg_to_tskit(arg: "ARG", Ne: float = 1.0) -> "tskit.TreeSequence":
    """Convert *arg* to a :class:`tskit.TreeSequence`.

    The ARG is replayed left to right twice: a first pass discovers every
    node that takes part in a parent/child relationship, and a second pass
    emits one edge per (parent, child) pair for each interval between
    consecutive recombination breakpoints.

    Parameters
    ----------
    arg : ARG
        The pysinger ARG object.
    Ne : float
        Effective population size used to convert coalescent time units
        to generations (t_generations = t_coalescent * Ne).

    Returns
    -------
    ts : tskit.TreeSequence

    Raises
    ------
    ImportError
        If tskit is not installed.
    """
    if not _HAS_TSKIT:
        raise ImportError(
            "tskit is required for arg_to_tskit(). "
            "Install it with: pip install tskit"
        )

    seq_len = arg.sequence_length
    tables = tskit.TableCollection(sequence_length=seq_len)

    # Interior breakpoints in increasing order.  Both passes below replay
    # the recombinations with forward_update(), which is order-dependent, so
    # they must share the same left-to-right ordering.
    recomb_positions = sorted(
        pos for pos in arg.recombinations.keys() if 0 < pos < seq_len
    )

    # ------------------------------------------------------------------
    # 1. Discover all nodes by walking the full tree sequence
    # ------------------------------------------------------------------
    # arg.node_set may be empty (add_node() is not always called), so we
    # discover nodes by replaying all recombinations and collecting every
    # node that appears in a parent/child relationship.
    # Nodes are keyed by id() to avoid hashing issues; the dict also keeps
    # each node alive so its id() stays unique for the whole conversion.
    nodes_by_id = {}  # id(node) -> node

    def _collect(tree_obj) -> None:
        for child, parent in tree_obj.parents.items():
            for n in (child, parent):
                # index == -1 marks the root sentinel, which is not exported.
                if n is not None and n.index != -1:
                    nodes_by_id.setdefault(id(n), n)

    tree = arg.get_tree_at(0.0)
    _collect(tree)
    for rpos in recomb_positions:
        tree.forward_update(arg.recombinations[rpos])
        _collect(tree)

    # Assign tskit ids in (time, index) order so numbering is deterministic.
    node_map = {}  # id(node) -> tskit node id
    for n in sorted(nodes_by_id.values(), key=lambda n: (n.time, n.index)):
        is_sample = n in arg.sample_nodes
        node_map[id(n)] = tables.nodes.add_row(
            flags=tskit.NODE_IS_SAMPLE if is_sample else 0,
            time=n.time * Ne,  # 1 SINGER unit = Ne generations (haploid coalescent)
        )

    # ------------------------------------------------------------------
    # 2. Walk trees and emit edges
    # ------------------------------------------------------------------
    # Replay recombinations left to right over the full sequence; at each
    # topology change emit edges for the current tree interval.
    def _emit_edges(tree_obj, left: float, right: float) -> None:
        if left >= right:
            return
        for child, parent in tree_obj.parents.items():
            if child is None or parent is None:
                continue  # no relationship to record
            if parent.index == -1:
                continue  # skip edges to root sentinel
            if parent.time <= child.time:
                continue  # skip time-ordering violations (malformed ARG state)
            c_id = node_map.get(id(child))
            p_id = node_map.get(id(parent))
            if c_id is None or p_id is None:
                continue
            tables.edges.add_row(left=left, right=right, parent=p_id, child=c_id)

    tree = arg.get_tree_at(0.0)
    prev_pos = 0.0
    for rpos in recomb_positions:
        _emit_edges(tree, prev_pos, rpos)
        tree.forward_update(arg.recombinations[rpos])
        prev_pos = rpos
    _emit_edges(tree, prev_pos, seq_len)

    # ------------------------------------------------------------------
    # 3. Sort and build tree sequence
    # ------------------------------------------------------------------
    tables.sort()
    return tables.tree_sequence()