#! /usr/bin/env python
from treeswift.Node import Node
from collections import deque
from copy import copy
from gzip import open as gopen
from math import ceil,log
from os.path import expanduser,isfile
from warnings import warn
# Error-message strings for malformed input files
# (presumably raised by the Newick/NeXML/Nexus readers defined elsewhere in this module — confirm)
INVALID_NEWICK = "Tree not valid Newick tree"
INVALID_NEXML = "Invalid NeXML file"
INVALID_NEXUS = "Invalid Nexus file"
# Euler-Mascheroni constant; used by Tree.colless() for the Yule-model normalization
EULER_GAMMA = 0.5772156649015328606065120900824024310421
[docs]
class Tree:
'''``Tree`` class'''
def __init__(self, is_rooted=True):
    '''``Tree`` constructor

    Args:
        ``is_rooted`` (``bool``): ``True`` if this ``Tree`` is rooted, otherwise ``False``

    Raises:
        ``TypeError``: if ``is_rooted`` is not a ``bool``
    '''
    if isinstance(is_rooted, bool):
        self.root = Node()  # a new Tree starts as a single root Node
        self.is_rooted = is_rooted  # whether this Tree is rooted
    else:
        raise TypeError("is_rooted must be a bool")
def __str__(self):
'''Represent this ``Tree`` as a string
Returns:
``str``: string representation of this ``Tree`` (Newick string)
'''
return self.newick()
def __copy__(self):
    '''Copy this ``Tree`` (invoked by ``copy.copy``)

    Returns:
        ``Tree``: the copy built by ``extract_tree(None, False, False)``
    '''
    duplicate = self.extract_tree(None, False, False)
    return duplicate
[docs]
def avg_branch_length(self, terminal=True, internal=True):
    '''Compute the average length of the selected branches of this ``Tree``. Edges with length ``None`` are skipped: they contribute neither to the total nor to the count

    Args:
        ``terminal`` (``bool``): ``True`` to include terminal branches, otherwise ``False``

        ``internal`` (``bool``): ``True`` to include internal branches, otherwise ``False``

    Returns:
        The average length of the selected branches

    Raises:
        ``TypeError``: if ``terminal`` or ``internal`` is not a ``bool``

        ``RuntimeError``: if neither terminal nor internal branches are selected

        ``ZeroDivisionError``: if none of the selected branches has a length
    '''
    if not isinstance(terminal, bool):
        raise TypeError("terminal must be a bool")
    if not isinstance(internal, bool):
        raise TypeError("internal must be a bool")
    if not internal and not terminal:
        raise RuntimeError("Must select either internal or terminal branches (or both)")
    tot = 0.; num = 0
    for node in self.traverse_preorder():
        if node.edge_length is None:
            continue  # length-less edges are excluded from the average entirely
        is_leaf = node.is_leaf()
        if (terminal and is_leaf) or (internal and not is_leaf):
            tot += node.edge_length; num += 1
    if num == 0:
        raise ZeroDivisionError("None of the selected branches have a length, so their average is undefined")
    return tot/num
[docs]
def branch_lengths(self, terminal=True, internal=True):
    '''Yield the length of every selected branch of this ``Tree`` in preorder. A branch whose length is ``None`` is yielded as 0

    Args:
        ``terminal`` (``bool``): ``True`` to include terminal branches, otherwise ``False``

        ``internal`` (``bool``): ``True`` to include internal branches, otherwise ``False``
    '''
    if not isinstance(terminal, bool):
        raise TypeError("terminal must be a bool")
    if not isinstance(internal, bool):
        raise TypeError("internal must be a bool")
    for node in self.traverse_preorder():
        selected = terminal if node.is_leaf() else internal
        if selected:
            yield 0 if node.edge_length is None else node.edge_length
[docs]
def closest_leaf_to_root(self):
    '''Find the leaf nearest to the root. Edges without a length count as 0; note that the root's own edge length (if any) is included in every distance

    Returns:
        ``tuple``: ``(leaf, distance)`` for the closest leaf, or ``(None, inf)`` if no leaf was visited
    '''
    dist = {}
    best_node, best_dist = None, float('inf')
    for node in self.traverse_preorder():
        here = 0 if node.edge_length is None else node.edge_length
        if not node.is_root():
            here += dist[node.parent]
        dist[node] = here
        if node.is_leaf() and here < best_dist:
            best_node, best_dist = node, here
    return (best_node, best_dist)
[docs]
def coalescence_times(self, backward=True):
    '''Generator over the times of successive coalescence events, i.e., the root-to-node distances of every ``Node`` with more than one child. Node labels are irrelevant: unlabeled nodes are included

    Args:
        ``backward`` (``bool``): ``True`` to go backward in time (i.e., leaves to root), otherwise ``False``

    Raises:
        ``TypeError``: if ``backward`` is not a ``bool``
    '''
    if not isinstance(backward, bool):
        raise TypeError("backward must be a bool")
    # unlabeled=True: internal nodes are usually unlabeled, and excluding them would silently drop coalescence events
    times = [d for n, d in self.distances_from_root(leaves=False, unlabeled=True) if len(n.children) > 1]
    yield from sorted(times, reverse=backward)
[docs]
def coalescence_waiting_times(self, backward=True):
    '''Generator over the waiting times of successive coalescence events: the gaps between consecutive coalescence times (nodes with more than one child), with the deepest leaf's distance from the root appended as the final time point. Node labels are irrelevant: unlabeled nodes are included

    Args:
        ``backward`` (``bool``): ``True`` to go backward in time (i.e., leaves to root), otherwise ``False``

    Raises:
        ``TypeError``: if ``backward`` is not a ``bool``
    '''
    if not isinstance(backward, bool):
        raise TypeError("backward must be a bool")
    times = []; lowest_leaf_dist = float('-inf')
    # unlabeled=True: coalescence events must not depend on whether nodes carry labels
    for n, d in self.distances_from_root(unlabeled=True):
        if len(n.children) > 1:
            times.append(d)
        elif len(n.children) == 0 and d > lowest_leaf_dist:
            lowest_leaf_dist = d
    times.append(lowest_leaf_dist)
    times.sort(reverse=backward)
    for earlier, later in zip(times, times[1:]):
        yield abs(earlier - later)
[docs]
def collapse_short_branches(self, threshold):
    '''Collapse every internal (non-terminal) branch whose length is at most ``threshold``; a length of ``None`` counts as 0. The children of a collapsed ``Node`` are re-attached to its parent. If the root's edge qualifies, its length is set to ``None``

    Args:
        ``threshold`` (``float``): The threshold to use when collapsing branches

    Raises:
        ``RuntimeError``: if ``threshold`` is not an ``int``/``float`` or is negative
    '''
    if not isinstance(threshold, (float, int)):
        raise RuntimeError("threshold must be an integer or a float")
    if threshold < 0:
        raise RuntimeError("threshold cannot be negative")
    pending = deque([self.root])
    while pending:
        current = pending.popleft()
        length = current.edge_length
        if length is None or length <= threshold:
            if current.is_root():
                current.edge_length = None
            elif not current.is_leaf():
                parent = current.parent
                parent.remove_child(current)
                for child in current.children:
                    parent.add_child(child)
        pending.extend(current.children)  # keep walking below, whether or not current was collapsed
[docs]
def colless(self, normalize='leaves'):
    '''Compute the Colless balance index of this ``Tree``. If the tree has polytomies, they will be randomly resolved (on a copy; this ``Tree`` itself is not modified). Unifurcations (nodes with a single child) contribute no imbalance

    Args:
        ``normalize`` (``str``): How to normalize the Colless index (if at all)

            * ``None`` (or ``False``) to not normalize

            * ``"leaves"`` to normalize by the number of leaves

            * ``"yule"`` to normalize to the Yule model

            * ``"pda"`` to normalize to the Proportional to Distinguishable Arrangements model

    Returns:
        ``float``: Colless index (either normalized or not)

    Raises:
        ``TypeError``: if ``normalize`` is neither ``None``/``False`` nor a string

        ``RuntimeError``: if ``normalize`` is an unrecognized string

        ``ZeroDivisionError``: for ``"leaves"`` normalization of a tree with 2 or fewer leaves
    '''
    # validate normalize up front, before the copy + traversal work
    if normalize is None or normalize is False:
        mode = None
    elif not isinstance(normalize, str):
        raise TypeError("normalize must be None or a string")
    else:
        mode = normalize.lower()
        if mode not in ('leaves', 'yule', 'pda'):
            raise RuntimeError("normalize must be None, 'leaves', 'yule', or 'pda'")
    t_res = copy(self); t_res.resolve_polytomies()
    leaves_below = {}; n = 0; I = 0
    for node in t_res.traverse_postorder():
        if node.is_leaf():
            leaves_below[node] = 1; n += 1
        elif len(node.children) == 1:
            # unifurcation: no left/right imbalance, just pass the leaf count upward
            leaves_below[node] = leaves_below[node.children[0]]
        else:
            cl, cr = node.children  # assumes resolve_polytomies() left exactly two children here
            nl = leaves_below[cl]; nr = leaves_below[cr]
            leaves_below[node] = nl+nr; I += abs(nl-nr)
    if mode is None:
        return I
    if mode == 'leaves':
        return (2.*I)/((n-1)*(n-2))
    if mode == 'yule':
        return (I - n*log(n) - n*(EULER_GAMMA-1-log(2)))/n
    return I/(n**1.5)  # mode == 'pda'
[docs]
def condense(self):
    '''If siblings have the same label, merge them. If they have edge lengths, the resulting ``Node`` will have the larger of the lengths

    Concretely: this first calls ``resolve_polytomies()`` on this ``Tree`` (modifying it in place). Then every
    internal ``Node`` whose descendant leaves all carry one and the same label is turned into a leaf with that
    label: its children are dropped, and the longest downward path length to a leaf below it is added to its
    own edge length (a ``None`` edge length is treated as 0 when adding). The walk goes from the root downward
    and does not descend into a condensed ``Node``, so the highest qualifying ``Node`` is the one condensed.
    '''
    self.resolve_polytomies(); labels_below = {}; longest_leaf_dist = {}
    # Postorder pass: for each node, collect the labels of the leaves below it, and the longest path
    # length down to one of those leaves (None if no edge along the way has a length)
    for node in self.traverse_postorder():
        if node.is_leaf():
            labels_below[node] = [node.label]; longest_leaf_dist[node] = None
        else:
            labels_below[node] = set()
            for c in node.children:
                labels_below[node].update(labels_below[c])
                # length of the longest path from node down through child c
                d = longest_leaf_dist[c]
                if c.edge_length is not None:
                    if d is None:
                        d = 0
                    d += c.edge_length
                # keep the maximum: a stored None is overwritten by whatever comes next,
                # but a None candidate never overwrites a known length
                if node not in longest_leaf_dist or longest_leaf_dist[node] is None or (d is not None and d > longest_leaf_dist[node]):
                    longest_leaf_dist[node] = d
    # Depth-first pass (stack) from the root: condense the highest single-label subtrees
    nodes = deque(); nodes.append(self.root)
    while len(nodes) != 0:
        node = nodes.pop()
        if node.is_leaf():
            continue
        if len(labels_below[node]) == 1:
            # every leaf below has the same label: replace this subtree by a single leaf
            # NOTE(review): only node.children is reassigned; the dropped children's parent attributes are not touched here
            node.label = labels_below[node].pop(); node.children = []
            if longest_leaf_dist[node] is not None:
                if node.edge_length is None:
                    node.edge_length = 0
                node.edge_length += longest_leaf_dist[node]
        else:
            nodes.extend(node.children)
[docs]
def contract_low_support(self, threshold, terminal=False, internal=True):
    '''Contract nodes labeled by a number (e.g. branch support) below ``threshold``. Nodes whose string form is not a number (e.g. unlabeled nodes) are left untouched

    Args:
        ``threshold`` (``float``): The support threshold to use when contracting nodes

        ``terminal`` (``bool``): ``True`` to include terminal branches, otherwise ``False``

        ``internal`` (``bool``): ``True`` to include internal branches, otherwise ``False``

    Raises:
        ``TypeError``: if ``threshold`` is not an ``int`` or ``float``
    '''
    if not isinstance(threshold, float) and not isinstance(threshold, int):
        raise TypeError("threshold must be float or int")
    to_contract = []
    for node in self.traverse_preorder(leaves=terminal, internal=internal):
        try:
            support = float(str(node))
        except ValueError:
            continue  # non-numeric (or empty) label: not a support value, so leave the node alone
        if support < threshold:
            to_contract.append(node)
    # collect first, then contract, so the traversal above never walks a tree that is being restructured
    for node in to_contract:
        node.contract()
[docs]
def drop_edge_length_at_root(self, label='OLDROOT'):
    '''Move the root's edge length (if it has one) onto a newly created child of the root, leaving the root with no edge length

    Args:
        ``label`` (``str``): The label given to the newly created child ``Node``
    '''
    root_length = self.root.edge_length
    if root_length is None:
        return
    self.root.add_child(Node(edge_length=root_length, label=label))
    self.root.edge_length = None
[docs]
def deroot(self):
    '''If tree bifurcates at the root, contract an edge incident with the root to create a trifurcation at the root.

    One child of the root must be a leaf; the other child is contracted. If either of the two root edges has
    a length, the contracted child's length (``None`` treated as 0) is first added onto the leaf's edge and the
    contracted edge is set to 0, so the length is not lost (presumably ``contract()`` then hangs the contracted
    child's children directly off the root). Afterwards ``is_rooted`` is ``False``.

    Raises:
        ``RuntimeError``: if the root does not have exactly two children, or neither child is a leaf
    '''
    if self.root.num_children() != 2:
        raise RuntimeError("Can only deroot a tree with degree-2 node at the root")
    children = self.root.child_nodes()
    if children[0].is_leaf():
        to_keep, to_del = children
    elif children[1].is_leaf():
        to_del, to_keep = children
    else:
        raise RuntimeError("Can only deroot a tree where one child of root is a leaf")
    # transfer the length if EITHER edge has one; if both are None there is nothing to move
    if to_keep.edge_length is not None or to_del.edge_length is not None:
        if to_del.edge_length is None:
            to_del.edge_length = 0
        if to_keep.edge_length is None:
            to_keep.edge_length = 0
        to_keep.edge_length += to_del.edge_length; to_del.edge_length = 0
    to_del.contract(); self.is_rooted = False
[docs]
def diameter(self):
    '''Compute the diameter (maximum leaf pairwise distance) of this ``Tree``. Edges with length ``None`` are treated as 0-length

    Returns:
        ``float``: The diameter of this Tree. If the whole tree is a single ``Node``, that ``Node``'s edge
        length is returned as-is; if no ``Node`` has two or more children (so no leaf pair exists),
        ``float('-inf')`` is returned
    '''
    depth = {}; best = float('-inf')  # depth[node] = longest path from node down to a leaf
    for node in self.traverse_postorder():
        if node.is_leaf():
            if node.is_root():
                return node.edge_length
            depth[node] = 0
        else:
            # a None edge length counts as 0 (instead of raising TypeError on None + number)
            dists = sorted(depth[c] + (0 if c.edge_length is None else c.edge_length) for c in node.children)
            depth[node] = dists[-1]
            if len(dists) > 1:  # ignore unifurcations when computing max pairwise leaf dist
                max_pair = dists[-1] + dists[-2]
                if max_pair > best:
                    best = max_pair
    return best
[docs]
def distance_between(self, u, v):
    '''Return the distance between nodes ``u`` and ``v`` in this ``Tree``. Edges with length ``None`` are treated as 0-length

    Args:
        ``u`` (``Node``): Node ``u``

        ``v`` (``Node``): Node ``v``

    Returns:
        ``float``: The distance between nodes ``u`` and ``v``

    Raises:
        ``TypeError``: if ``u`` or ``v`` is not a ``Node``

        ``RuntimeError``: if ``u`` and ``v`` do not share an ancestor (different trees)
    '''
    if not isinstance(u, Node):
        raise TypeError("u must be a Node")
    if not isinstance(v, Node):
        raise TypeError("v must be a Node")
    if u == v:
        return 0.
    # fast paths for adjacent nodes; a None length counts as 0, exactly like the general case below
    if u == v.parent:
        return 0. if v.edge_length is None else v.edge_length
    if v == u.parent:
        return 0. if u.edge_length is None else u.edge_length
    # distance from u up to each of its ancestors
    u_dists = {u: 0.}
    c = u; p = u.parent
    while p is not None:
        u_dists[p] = u_dists[c]
        if c.edge_length is not None:
            u_dists[p] += c.edge_length
        if p == v:
            return u_dists[p]
        c = p; p = p.parent
    # walk up from v until reaching an ancestor of u (the MRCA)
    v_dists = {v: 0.}
    c = v; p = v.parent
    while p is not None:
        v_dists[p] = v_dists[c]
        if c.edge_length is not None:
            v_dists[p] += c.edge_length
        if p in u_dists:
            return u_dists[p] + v_dists[p]
        c = p; p = p.parent
    raise RuntimeError("u and v are not in the same Tree")
[docs]
def distance_matrix(self, leaf_labels=False):
    '''Return a distance matrix (2D dictionary) of the leaves of this ``Tree``. Edges with length ``None`` are treated as 0-length

    Args:
        ``leaf_labels`` (``bool``): ``True`` to have keys be labels of leaf ``Node`` objects, otherwise ``False`` to have keys be ``Node`` objects

    Returns:
        ``dict``: ``M[u][v]`` = distance from leaf ``u`` to leaf ``v``, keyed by leaf labels (if ``leaf_labels``) or by leaf ``Node`` objects; ``M[u][u]`` is not stored
    '''
    M = {}
    below = {}  # node -> list of [leaf, distance from node down to that leaf]; consumed as the walk moves up
    for node in self.traverse_postorder():
        if node.is_leaf():
            below[node] = [[node, 0]]
            continue
        kids = node.children
        # lengthen each child's leaf distances by the child's own edge
        for kid in kids:
            if kid.edge_length is not None:
                for entry in below[kid]:
                    entry[1] += kid.edge_length
        # every pair of leaves under two different children meets at this node
        for a in range(len(kids) - 1):
            for b in range(a + 1, len(kids)):
                for u, ud in below[kids[a]]:
                    for v, vd in below[kids[b]]:
                        dist = ud + vd
                        u_key, v_key = (u.label, v.label) if leaf_labels else (u, v)
                        M.setdefault(u_key, {})[v_key] = dist
                        M.setdefault(v_key, {})[u_key] = dist
        # hand the merged leaf list up to this node, releasing the children's entries
        merged = below.pop(kids[0])
        for kid in kids[1:]:
            merged += below.pop(kid)
        below[node] = merged
    return M
[docs]
def distances_from_parent(self, leaves=True, internal=True, unlabeled=False):
    '''Yield ``(node, distance to its parent)`` tuples in preorder for the selected nodes; a ``None`` edge length is reported as 0

    Args:
        ``leaves`` (``bool``): ``True`` to include leaves, otherwise ``False``

        ``internal`` (``bool``): ``True`` to include internal nodes, otherwise ``False``

        ``unlabeled`` (``bool``): ``True`` to include unlabeled nodes, otherwise ``False``
    '''
    if not isinstance(leaves, bool):
        raise TypeError("leaves must be a bool")
    if not isinstance(internal, bool):
        raise TypeError("internal must be a bool")
    if not isinstance(unlabeled, bool):
        raise TypeError("unlabeled must be a bool")
    if not (leaves or internal):
        return
    for node in self.traverse_preorder():
        if not (leaves if node.is_leaf() else internal):
            continue
        if node.label is None and not unlabeled:
            continue
        yield (node, 0 if node.edge_length is None else node.edge_length)
[docs]
def distances_from_root(self, leaves=True, internal=True, unlabeled=False, weighted=True):
    '''Yield ``(node, distance from the root)`` tuples in preorder for the selected nodes. The root is at distance 0 (its own edge is not counted)

    Args:
        ``leaves`` (``bool``): ``True`` to include leaves, otherwise ``False``

        ``internal`` (``bool``): ``True`` to include internal nodes, otherwise ``False``

        ``unlabeled`` (``bool``): ``True`` to include unlabeled nodes, otherwise ``False``

        ``weighted`` (``bool``): ``True`` to sum edge lengths (``None`` counts as 0), or ``False`` to count edges (every edge counts as 1, including those with length ``None``)
    '''
    if not isinstance(leaves, bool):
        raise TypeError("leaves must be a bool")
    if not isinstance(internal, bool):
        raise TypeError("internal must be a bool")
    if not isinstance(unlabeled, bool):
        raise TypeError("unlabeled must be a bool")
    if not (leaves or internal):
        return
    dist = {}
    for node in self.traverse_preorder():
        if node.is_root():
            here = 0
        elif not weighted:
            here = dist[node.parent] + 1
        elif node.edge_length is None:
            here = dist[node.parent]
        else:
            here = dist[node.parent] + node.edge_length
        dist[node] = here
        wanted = leaves if node.is_leaf() else internal
        if wanted and (unlabeled or node.label is not None):
            yield (node, here)
[docs]
def draw(self, show_plot=True, export_filename=None, show_labels=False, align_labels=False, label_fontsize=8, start_time=0, default_color='#000000', xlabel=None, handles=None):
    '''Draw this ``Tree``

    Args:
        ``show_plot`` (``bool``): ``True`` to show the plot, otherwise ``False``

        ``export_filename`` (``str``): File to which the tree figure will be exported (otherwise ``None`` to not save to file)

        ``show_labels`` (``bool``): ``True`` to show the leaf labels, otherwise ``False``

        ``align_labels`` (``bool``): Accepted for compatibility but currently ignored; labels are always placed by their tips

        ``label_fontsize`` (``int``): Font size of the leaf labels (in points). 8pt = 1/9in --> 1in = 72pt

        ``start_time`` (``float``): Horizontal coordinate of the root (a root edge, if any, is drawn to its left)

        ``default_color`` (``str``): The default color to use if a node doesn't have a ``color`` attribute

        ``xlabel`` (``str``): The label of the horizontal axis in the resulting plot

        ``handles`` (``list``): List of matplotlib ``Patch`` objects for a legend

    The figure is exported (if requested) before it is shown. It is closed afterwards if it was shown or
    exported; if it was neither, it is left open so the caller can still customize or save it. The global
    matplotlib ``rcParams`` changed for drawing are always restored, even if drawing fails.
    '''
    import matplotlib.pyplot as plt
    from matplotlib import rcParams
    spine_keys = ('axes.spines.left', 'axes.spines.right', 'axes.spines.top')
    orig = {k: rcParams[k] for k in spine_keys}
    try:
        for k in spine_keys:
            rcParams[k] = False  # hide left/right/top spines (must be set before the figure is created)
        # compute total height needed at each node (one unit per leaf below it)
        dy = {}
        for node in self.traverse_postorder():
            if node.is_leaf():
                dy[node] = 1
            else:
                dy[node] = sum(dy[child] for child in node.children)
        # compute y-coordinate of each node: children are stacked top-to-bottom within their parent's span
        y = {self.root: 0}  # root is at y = 0
        for node in self.traverse_preorder(leaves=False):
            y_top = y[node] + (dy[node]/2)
            for child in node.children:
                y[child] = y_top - (dy[child]/2)
                y_top -= dy[child]
        # compute x-coordinate of each node: root at start_time, each child offset by its edge length
        x = {}
        for node in self.traverse_preorder():
            if node.is_root():
                x[node] = start_time
            else:
                x[node] = x[node.parent]
                if node.edge_length is not None:
                    x[node] += node.edge_length
        # compute figure width and height (inches)
        width = 10  # arbitrary choice
        if show_labels:
            height_per_leaf = label_fontsize/50.  # arbitrarily chose 50 to get it to look nice
        else:
            height_per_leaf = 5./50.
        height = 2./8.  # height of x-axis is around 2/8 of an inch
        height += dy[self.root] * height_per_leaf  # add tree height
        if xlabel is not None:
            height += 1./8.  # height of x-label is around 1/8 of an inch
        height += 4./8.  # arbitrary additional padding to get it to look nice for smaller graphs
        # plot tree
        fig, ax = plt.subplots(figsize=(width, height))
        try:
            ax.ticklabel_format(useOffset=False)  # disable +- from center
            ax.get_yaxis().set_visible(False)  # hide y-axis
            for node in self.traverse_preorder():
                curr_color = getattr(node, 'color', None)
                if curr_color is None:
                    curr_color = default_color
                el = 0 if node.edge_length is None else node.edge_length
                ax.plot([x[node]-el, x[node]], [y[node], y[node]], color=curr_color)  # horizontal line into node
                if len(node.children) > 1:
                    # vertical line spanning from the first to the last child
                    ax.plot([x[node], x[node]], [y[node.children[0]], y[node.children[-1]]], color=curr_color)
                elif node.is_leaf() and show_labels:
                    ax.text(x[node], y[node], f" {str(node)}", fontsize=label_fontsize, verticalalignment='center', color=curr_color)
            # export/show
            if xlabel is not None:
                ax.set_xlabel(xlabel)
            if handles is None:
                legend = None
            else:
                legend = ax.legend(handles=handles, loc='upper right', bbox_to_anchor=(0,1))
            fig.tight_layout()
            # save BEFORE showing: plt.show() may close the figure, so a later savefig can write a blank image
            if export_filename is not None:
                if legend is None:
                    fig.savefig(export_filename)
                else:
                    fig.savefig(export_filename, bbox_extra_artists=(legend,))
            if show_plot:
                plt.show()
        except BaseException:
            plt.close(fig)  # do not leak a half-drawn figure
            raise
        if show_plot or export_filename is not None:
            plt.close(fig)
    finally:
        # always restore the global rcParams changed above, even if drawing failed
        for k, v in orig.items():
            rcParams[k] = v
[docs]
def edge_length_sum(self, terminal=True, internal=True):
    '''Add up the lengths of the selected edges of this ``Tree``; edges whose length is ``None`` are skipped

    Args:
        ``terminal`` (``bool``): ``True`` to include terminal branches, otherwise ``False``

        ``internal`` (``bool``): ``True`` to include internal branches, otherwise ``False``

    Returns:
        ``float``: Sum of all selected edge lengths in this ``Tree``
    '''
    if not isinstance(terminal, bool):
        raise TypeError("terminal must be a bool")
    if not isinstance(internal, bool):
        raise TypeError("internal must be a bool")
    selected = (node for node in self.traverse_preorder()
                if node.edge_length is not None and (terminal if node.is_leaf() else internal))
    return sum(node.edge_length for node in selected)
[docs]
def find_node(self, label, leaves=True, internal=False):
    '''Find and return the node(s) labeled by ``label``, or ``None`` if none exist. Note that this function performs a linear-time search, so if you will be calling it many times, build a lookup dictionary once with ``label_to_node`` instead

    Args:
        ``label`` (``str``): The label to search for

        ``leaves`` (``bool``): ``True`` to include leaves, otherwise ``False``

        ``internal`` (``bool``): ``True`` to include internal nodes, otherwise ``False``

    Returns:
        The ``Node`` object labeled by ``label`` (or a ``list`` of ``Node`` objects, in preorder, if multiple are labeled by ``label``), or ``None`` if none exist
    '''
    matches = [node for node in self.traverse_preorder(leaves=leaves, internal=internal) if node.label == label]
    if not matches:
        return None
    return matches[0] if len(matches) == 1 else matches
[docs]
def furthest_from_root(self):
    '''Find the ``Node`` farthest from the root. Edges without a length count as 0; the root's own edge length (if any) is included in every distance

    Returns:
        ``tuple``: ``(node, distance)`` for the farthest ``Node``; ``(root, 0)`` if no distance exceeds 0
    '''
    dist = {}
    best_node, best_dist = self.root, 0
    for node in self.traverse_preorder():
        here = 0 if node.edge_length is None else node.edge_length
        if not node.is_root():
            here += dist[node.parent]
        dist[node] = here
        if here > best_dist:
            best_node, best_dist = node, here
    return (best_node, best_dist)
[docs]
def gamma_statistic(self):
    '''Compute the Gamma statistic of Pybus and Harvey (2000)

    gamma = [ (1/(n-2)) * sum_{i=2}^{n-1} ( sum_{k=2}^{i} k*g_k ) - T/2 ] / ( T * sqrt(1/(12(n-2))) ),
    where ``g_k`` is the internode interval during which there are ``k`` lineages and ``T = sum_{j=2}^{n} j*g_j``.
    Polytomies are resolved on a copy first; this ``Tree`` is not modified.

    Returns:
        ``float``: The Gamma statistic of Pybus and Harvey (2000)

    Raises:
        ``RuntimeError``: if the tree has 2 or fewer leaves
    '''
    t = copy(self); t.resolve_polytomies()  # need fully bifurcating tree
    # The formula uses internode INTERVALS (waiting times), ordered from the root towards the tips,
    # so that G[j] = g_{j+2}; raw coalescence times (distances from the root) are not the g_k
    G = list(t.coalescence_waiting_times(backward=False))
    n = len(G)+1
    if n <= 2:
        raise RuntimeError("Gamma statistic can only be computed on trees with more than 2 leaves")
    T = sum((j+2)*g for j, g in enumerate(G))
    # sum_{i=2}^{n-1} sum_{k=2}^{i} k*g_k, via a running prefix sum (O(n) instead of O(n^2))
    prefix = 0.; out = 0.
    for k in range(n-2):
        prefix += (k+2)*G[k]
        out += prefix
    out /= (n-2)
    out -= (T/2)
    out /= T
    out /= (1./(12*(n-2)))**0.5
    return out
[docs]
def height(self, weighted=True):
    '''Compute the height (i.e., maximum distance from root) of this ``Tree``. Every ``Node`` is considered, labeled or not

    Args:
        ``weighted`` (``bool``): ``True`` to measure distance as the sum of edge lengths (``None`` counts as 0), or ``False`` to count edges

    Returns:
        ``float``: The height (i.e., maximum distance from root) of this ``Tree``
    '''
    # unlabeled=True: height must not depend on labels (a fully unlabeled tree would otherwise make max() fail)
    return max(dist for _, dist in self.distances_from_root(unlabeled=True, weighted=weighted))
[docs]
def indent(self, space=4):
'''Return an indented Newick string, just like ``nw_indent`` in Newick Utilities
Args:
``space`` (``int``): The number of spaces a tab should equal
Returns:
``str``: An indented Newick string
'''
if not isinstance(space,int):
raise TypeError("space must be an int")
if space < 0:
raise ValueError("space must be a non-negative integer")
space = ' '*space; o = []; l = 0
for c in self.newick():
if c == '(':
o.append('(\n'); l += 1; o.append(space*l)
elif c == ')':
o.append('\n'); l -= 1; o.append(space*l); o.append(')')
elif c == ',':
o.append(',\n'); o.append(space*l)
else:
o.append(c)
return ''.join(o)
[docs]
def label_to_node(self, selection='leaves'):
    '''Return a dictionary mapping labels (strings) to ``Node`` objects

    * If ``selection`` is ``"all"``, the dictionary will contain all nodes

    * If ``selection`` is ``"leaves"``, the dictionary will only contain leaves

    * If ``selection`` is ``"internal"``, the dictionary will only contain internal nodes

    * If ``selection`` is a ``set`` (or ``list``), the dictionary will contain all nodes labeled by a label in ``selection``, and a warning is issued if some labels are not found

    * If multiple nodes are labeled by a given label, only the last (preorder traversal) will be obtained

    Keys are ``str(node)``.

    Args:
        ``selection`` (``str`` or ``set``): The selection of nodes to get

            * ``"all"`` to select all nodes

            * ``"leaves"`` to select leaves

            * ``"internal"`` to select internal nodes

            * A ``set`` of labels to specify nodes to select

    Returns:
        ``dict``: Dictionary mapping labels to the corresponding nodes

    Raises:
        ``RuntimeError``: if ``selection`` is not one of the three strings, a ``set``, or a ``list``
    '''
    error_msg = '"selection" must be one of the strings "all", "leaves", or "internal", or it must be a set containing Node labels'
    if isinstance(selection, list):
        selection = set(selection)
    if isinstance(selection, str):
        if selection not in ('all', 'leaves', 'internal'):
            raise RuntimeError(error_msg)
        # select purely by mode; never test str(node) against the mode string, since a
        # substring test would let e.g. empty labels match any mode
        label_to_node = {}
        for node in self.traverse_preorder():
            if selection == 'all' or (selection == 'leaves' and node.is_leaf()) or (selection == 'internal' and not node.is_leaf()):
                label_to_node[str(node)] = node
        return label_to_node
    if not isinstance(selection, set):
        raise RuntimeError(error_msg)
    label_to_node = {str(node): node for node in self.traverse_preorder() if str(node) in selection}
    if len(label_to_node) != len(selection):
        warn("Not all given labels exist in the tree")
    return label_to_node
[docs]
def labels(self, leaves=True, internal=True):
    '''Yield, in preorder, the labels of the selected nodes, skipping nodes whose label is ``None``

    Args:
        ``leaves`` (``bool``): ``True`` to include leaves, otherwise ``False``

        ``internal`` (``bool``): ``True`` to include internal nodes, otherwise ``False``
    '''
    if not isinstance(leaves, bool):
        raise TypeError("leaves must be a bool")
    if not isinstance(internal, bool):
        raise TypeError("internal must be a bool")
    for node in self.traverse_preorder():
        if node.label is None:
            continue
        if (leaves if node.is_leaf() else internal):
            yield node.label
[docs]
def ladderize(self, ascending=True):
    '''Ladderize this ``Tree`` by sorting each ``Node``'s children by number of descendants, breaking ties by incident edge length and then by label

    Args:
        ``ascending`` (``bool``): ``True`` to sort children in ascending order, otherwise ``False``
    '''
    self.order('num_descendants_then_edge_length_then_label', ascending=ascending)
[docs]
def lineages_through_time(self, present_day=None, show_plot=True, export_filename=None, color='#000000', xmin=None, xmax=None, ymin=None, ymax=None, title=None, xlabel=None, ylabel=None):
    '''Compute the number of lineages through time. If seaborn is installed, a plot is shown as well

    Edges with length ``None`` add no time, and neither open nor close a lineage. This ``Tree`` is left
    unmodified even if the computation fails.

    Args:
        ``present_day`` (``float``): The time of the furthest node from the root. If ``None``, the top of the tree will be placed at time 0

        ``show_plot`` (``bool``): ``True`` to show the plot, otherwise ``False`` to only return the dictionary. To plot multiple LTTs on the same figure, set ``show_plot`` to False for all but the last plot

        ``export_filename`` (``str``): File to which the LTT figure will be exported (otherwise ``None`` to not save to file)

        ``color`` (``str``): The color of the resulting plot

        ``title`` (``str``): The title of the resulting plot

        ``xmin`` (``float``): The minimum value of the horizontal axis in the resulting plot

        ``xmax`` (``float``): The maximum value of the horizontal axis in the resulting plot

        ``xlabel`` (``str``): The label of the horizontal axis in the resulting plot

        ``ymin`` (``float``): The minimum value of the vertical axis in the resulting plot

        ``ymax`` (``float``): The maximum value of the vertical axis in the resulting plot

        ``ylabel`` (``str``): The label of the vertical axis in the resulting plot

    Returns:
        ``dict``: A dictionary in which each ``(t,n)`` pair denotes the number of lineages ``n`` that existed at time ``t``

    Raises:
        ``TypeError``: if ``present_day`` is neither ``None`` nor a number
    '''
    if present_day is not None and not isinstance(present_day,int) and not isinstance(present_day,float):
        raise TypeError("present_day must be a float")
    # if the root has its own edge, hang it under a temporary parent so that edge is counted as a lineage
    if self.root.edge_length is None:
        tmproot = self.root
    else:
        tmproot = Node(); tmproot.add_child(self.root)
    try:
        # time of each node = distance from tmproot (None edge lengths add nothing)
        time = {}
        for node in tmproot.traverse_preorder():
            if node.is_root():
                time[node] = 0.
            else:
                time[node] = time[node.parent]
                if node.edge_length is not None:
                    time[node] += node.edge_length
        # sort by time only (stable), so Node objects never have to be compared on ties
        events = sorted(time.items(), key=lambda item: item[1])
        prev_t = events[0][1]
        lineages = {prev_t: 0}
        for node, t in events:
            if t not in lineages:
                lineages[t] = lineages[prev_t]  # carry forward the count from the previous event time
            prev_t = t
            # the lineage leading into this node ends here (a negative edge reverses the direction)
            if node.edge_length is not None:
                lineages[t] += -1 if node.edge_length >= 0 else 1
            # each child lineage starts here; length-less child edges are skipped, mirroring the check above
            for c in node.children:
                if c.edge_length is not None:
                    lineages[t] += 1 if c.edge_length >= 0 else -1
    finally:
        # detach the temporary parent even if counting failed, so this Tree is never left modified
        if tmproot is not self.root:
            self.root.parent = None
    if present_day is not None:
        shift = present_day - max(lineages.keys())
    else:
        shift = max(0, -min(lineages.keys()))
    if shift != 0:
        lineages = {t+shift: lineages[t] for t in lineages}
    try:
        plot_ltt(lineages, show_plot=show_plot, export_filename=export_filename, color=color, xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax, title=title, xlabel=xlabel, ylabel=ylabel)
    except Exception as e:  # plotting is best-effort; the counts are still returned
        warn(f"Unable to produce visualization (but dictionary will still be returned): {e}")
    return lineages
ltt = lineages_through_time # shorthand alias
[docs]
def mrca(self, labels):
    '''Return the Node that is the MRCA of the nodes labeled by a label in ``labels``. If multiple nodes are labeled by a given label, only the last (preorder traversal) will be obtained

    Args:
        ``labels`` (``set``): Set of leaf labels (any iterable is converted to a ``set``)

    Returns:
        ``Node``: The MRCA of the ``Node`` objects labeled by a label in ``labels``

    Raises:
        ``TypeError``: if ``labels`` cannot be converted to a ``set``

        ``RuntimeError``: if no ``Node`` is labeled by any label in ``labels``
    '''
    if not isinstance(labels, set):
        try:
            labels = set(labels)
        except TypeError as e:
            raise TypeError("labels must be iterable") from e
    l2n = self.label_to_node(labels)
    if not l2n:
        raise RuntimeError("No nodes in the tree are labeled by a label in labels")
    # Walk upward from every selected node, breadth-first; count[x] = number of selected nodes whose
    # upward walk has reached x. The first node reached by all of them is the lowest common ancestor
    count = {}; q = deque(l2n.values())
    while len(q) != 0:
        curr = q.popleft()
        count[curr] = count.get(curr, 0) + 1
        if count[curr] == len(l2n):
            return curr
        if curr.parent is not None:
            q.append(curr.parent)
    raise RuntimeError("There somehow does not exist an MRCA for the given labels")
[docs]
def mrca_matrix(self):
    '''Return a dictionary storing the MRCA of every pair of leaves. ``M[u][v]`` = MRCA of leaves ``u`` and ``v``. Excludes ``M[u][u]`` because MRCA of node and itself is itself

    Returns:
        ``dict``: ``M[u][v]`` = MRCA of leaves ``u`` and ``v`` (keys are leaf ``Node`` objects)
    '''
    M = {}
    leaves_below = {}
    for node in self.traverse_postorder():
        if node.is_leaf():
            leaves_below[node] = [node]; M[node] = {}
            continue
        groups = [leaves_below[c] for c in node.children]
        # every pair of leaves taken from two DIFFERENT child subtrees has this node as its MRCA
        for i in range(len(groups)-1):
            for l1 in groups[i]:
                for j in range(i+1, len(groups)):
                    for l2 in groups[j]:
                        M[l1][l2] = node; M[l2][l1] = node
        # pass every leaf up, including through unifurcations (a single child), so ancestors still see them
        leaves_below[node] = [leaf for group in groups for leaf in group]
    return M
[docs]
def newick(self):
    '''Serialize this ``Tree`` as a Newick string. Any root ``node_params``/``edge_params`` are emitted in square brackets, an integral float root edge length is written without a decimal part, and rooted trees are prefixed with ``[&R]``

    Returns:
        ``str``: Newick string of this ``Tree``
    '''
    root = self.root
    parts = []
    if hasattr(root, 'node_params'):
        parts.append(f'[{str(root.node_params)}]')
    has_edge_params = hasattr(root, 'edge_params')
    length = root.edge_length
    if length is not None or has_edge_params:
        parts.append(':')
        if has_edge_params:
            parts.append(f'[{str(root.edge_params)}]')
        if isinstance(length, float) and length.is_integer():
            parts.append(str(int(length)))
        elif length is not None:
            parts.append(str(length))
    parts.append(';')
    body = root.newick() + ''.join(parts)
    return f'[&R] {body}' if self.is_rooted else body
[docs]
def num_cherries(self):
    '''Count the cherries of this ``Tree``: internal nodes all of whose children are leaves

    Returns:
        ``int``: The number of cherries in this ``Tree``
    '''
    return sum(1 for node in self.traverse_internal() if all(child.is_leaf() for child in node.children))
[docs]
def num_lineages_at(self, distance):
    '''Count the lineages of this ``Tree`` crossing the given distance from the root (the root is at distance 0; ``None`` edge lengths count as 0)

    Args:
        ``distance`` (``float``): The distance away from the root

    Returns:
        ``int``: The number of lineages that exist ``distance`` away from the root

    Raises:
        ``TypeError``: if ``distance`` is not an ``int`` or ``float``

        ``RuntimeError``: if ``distance`` is negative
    '''
    if not isinstance(distance, (float, int)):
        raise TypeError("distance must be an int or a float")
    if distance < 0:
        raise RuntimeError("distance cannot be negative")
    depth = {}
    frontier = deque([self.root])
    total = 0
    while frontier:
        node = frontier.popleft()
        if node.is_root():
            here = 0
        elif node.edge_length is None:
            here = depth[node.parent]
        else:
            here = depth[node.parent] + node.edge_length
        depth[node] = here
        if here < distance:
            frontier.extend(node.children)  # still above the cut: keep descending
        elif node.parent is None or depth[node.parent] < distance:
            total += 1  # this edge crosses (or ends exactly at) the cut
    return total
[docs]
def num_nodes(self, leaves=True, internal=True):
    '''Count the selected nodes of this ``Tree`` (delegates to the root ``Node``)

    Args:
        ``leaves`` (``bool``): ``True`` to include leaves, otherwise ``False``

        ``internal`` (``bool``): ``True`` to include internal nodes, otherwise ``False``

    Returns:
        ``int``: The total number of selected nodes in this ``Tree``
    '''
    root = self.root
    return root.num_nodes(internal=internal, leaves=leaves)
[docs]
def order(self, mode, ascending=True):
    '''Sort the children of every node in this ``Tree`` according to ``mode``

    Args:
        ``mode`` (``str``): One to three sort criteria joined by ``"_then_"``, each used at most once:

            * ``"edge_length"`` = order by incident edge length
            * ``"label"`` = order by node label
            * ``"num_descendants"`` = order by number of descendants

            All 15 combinations are valid, e.g. ``"edge_length"``, ``"edge_length_then_label"``,
            ``"edge_length_then_label_then_num_descendants"``, ``"label_then_num_descendants"``,
            ``"num_descendants_then_edge_length_then_label"``, etc.

        ``ascending`` (``bool``): ``True`` to sort in ascending order of ``mode``, otherwise ``False``

    Raises:
        ``TypeError``: If ``mode`` is not a ``str`` or ``ascending`` is not a ``bool``
        ``ValueError``: If ``mode`` is not a valid combination of criteria
    '''
    if not isinstance(mode, str):
        raise TypeError("mode must be a str")
    if not isinstance(ascending, bool):
        raise TypeError("ascending must be a bool")
    # descendant counts (bottom-up) are only needed when that criterion is requested
    desc_count = {}
    if 'num_descendants' in mode:
        for node in self.traverse_postorder():
            desc_count[node] = 0 if node.is_leaf() else len(node.children) + sum(desc_count[c] for c in node.children)
    criteria = mode.split('_then_')
    known = ('edge_length', 'label', 'num_descendants')
    if len(set(criteria)) != len(criteria) or any(c not in known for c in criteria):
        raise ValueError("Invalid choice for mode")

    def sort_key(node):
        # None values sort first: each nullable field is preceded by an "is not None" flag
        parts = []
        for criterion in criteria:
            if criterion == 'edge_length':
                parts.extend((node.edge_length is not None, node.edge_length))
            elif criterion == 'label':
                parts.extend((node.label is not None, node.label))
            else:
                parts.append(desc_count[node])
        return tuple(parts)

    for node in self.traverse_preorder():
        node.children.sort(key=sort_key, reverse=not ascending)
[docs]
def rename_nodes(self, renaming_map):
    '''Relabel the nodes of this ``Tree`` whose current label appears in ``renaming_map``

    Args:
        ``renaming_map`` (``dict``): A dictionary mapping old labels (keys) to new labels (values)

    Raises:
        ``TypeError``: If ``renaming_map`` is not a ``dict``
    '''
    if not isinstance(renaming_map, dict):
        raise TypeError("renaming_map must be a dict")
    to_rename = (node for node in self.traverse_preorder() if node.label in renaming_map)
    for node in to_rename:
        node.label = renaming_map[node.label]
[docs]
def reroot(self, node, length=None, branch_support=False):
    '''Reroot this ``Tree`` at ``length`` up the incident edge of ``node``. If 0 or ``None``, reroot at the node (not on the incident edge)

    Args:
        ``node`` (``Node``): The ``Node`` on whose incident edge this ``Tree`` will be rerooted
        ``length`` (``float``): The distance up the specified edge at which to reroot this ``Tree``. If 0 or ``None``, reroot at the node (not on the incident edge)
        ``branch_support`` (``bool``): ``True`` if internal node labels represent branch support values, otherwise ``False``

    Raises:
        ``TypeError``: If ``node``, ``length``, or ``branch_support`` has the wrong type
        ``ValueError``: If ``length`` is negative, longer than the incident edge of ``node``, or nonzero when ``node`` has no edge length
    '''
    warn("TreeSwift's rerooting functionality is poorly tested and likely has bugs. It will be fixed in a future release")
    if not isinstance(node, Node):
        raise TypeError("node must be a Node")
    if length is not None and not isinstance(length, float) and not isinstance(length, int):
        raise TypeError("length must be a float")
    if not isinstance(branch_support, bool):
        raise TypeError("branch_support must be a bool")
    if length is not None and length < 0:
        raise ValueError("Specified length at which to reroot must be positive")
    if node.edge_length is None:
        if length is not None and length != 0:
            raise ValueError("Specified node has no edge length, so specified length must be None or 0")
    elif length is not None and length > node.edge_length:
        raise ValueError("Specified length must be shorter than the edge at which to reroot")

    # split the incident edge of `node`: `node` keeps `length`, a new node above it takes the remainder
    if length is not None and length > 0:
        newnode = Node(edge_length=node.edge_length-length); node.edge_length = length
        if node.is_root():
            # `node` was the root, so the inserted node must become the root; previously
            # self.root was left pointing at `node` even though it now has a parent
            newnode.add_child(node); self.root = newnode
        else:
            p = node.parent; p.children.remove(node); p.add_child(newnode)
            newnode.add_child(node)
        node = newnode
    if node.is_root():
        return
    elif self.root.edge_length is not None:
        # the current root has an incident edge: hang it under a new 'ROOT' node so that
        # edge takes part in the reversal below like any other edge
        newnode = Node(label='ROOT'); newnode.add_child(self.root); self.root = newnode

    # reverse every edge on the path between the root and `node`; presumably
    # traverse_ancestors yields `node` first and walks upward, so reversed() starts next to the root
    ancestors = [a for a in node.traverse_ancestors(include_self=True) if not a.is_root()]
    for curr in reversed(ancestors):
        curr.parent.edge_length = curr.edge_length; curr.edge_length = None
        if branch_support:
            # support values describe edges, so they move along with the edge
            curr.parent.label = curr.label; curr.label = None
        curr.parent.children.remove(curr); curr.add_child(curr.parent); curr.parent = None
    self.root = node; self.is_rooted = True
[docs]
def resolve_polytomies(self):
    '''Arbitrarily resolve polytomies with 0-lengthed edges (delegates to the root ``Node``)'''
    root = self.root
    root.resolve_polytomies()
[docs]
def sackin(self, normalize='leaves'):
    '''Compute the Sackin balance index of this ``Tree``: the sum, over all leaves, of the number of internal nodes on the path from the root

    Args:
        ``normalize`` (``str``): How to normalize the Sackin index (if at all)

            * ``None`` to not normalize
            * ``"leaves"`` to normalize by the number of leaves
            * ``"yule"`` to normalize to the Yule model
            * ``"pda"`` to normalize to the Proportional to Distinguishable Arrangements model

    Returns:
        ``float``: Sackin index (either normalized or not)

    Raises:
        ``TypeError``: If ``normalize`` is neither ``None``/``False`` nor a ``str``
        ``RuntimeError``: If ``normalize`` is an unknown string
    '''
    depth = {}
    total = 0
    leaf_count = 0
    for node in self.traverse_preorder():
        # number of nodes from the root down to (and including) this node
        here = 1 if node.is_root() else depth[node.parent] + 1
        if node.is_leaf():
            # the leaf itself is not counted
            here -= 1
            total += here
            leaf_count += 1
        depth[node] = here
    if normalize is None or normalize is False:
        return total
    if not isinstance(normalize, str):
        raise TypeError("normalize must be None or a string")
    scheme = normalize.lower()
    if scheme == 'leaves':
        return float(total)/leaf_count
    if scheme == 'yule':
        harmonic = sum(1./k for k in range(2, leaf_count+1))
        return (total - (2*leaf_count*harmonic)) / leaf_count
    if scheme == 'pda':
        return total/(leaf_count**1.5)
    raise RuntimeError("normalize must be None, 'leaves', 'yule', or 'pda'")
[docs]
def scale_edges(self, multiplier):
    '''Multiply every non-``None`` edge length in this ``Tree`` by ``multiplier``

    Args:
        ``multiplier`` (``int`` or ``float``): The scaling factor

    Raises:
        ``TypeError``: If ``multiplier`` is not an ``int`` or ``float``
    '''
    if not isinstance(multiplier, (int, float)):
        raise TypeError("multiplier must be an int or float")
    for node in self.traverse_preorder():
        if node.edge_length is None:
            continue
        node.edge_length = node.edge_length * multiplier
[docs]
def suppress_unifurcations(self):
    '''Remove every node that has exactly one child, attaching that child directly to the removed node's parent

    The removed node's edge length (if any) is added to the child's edge length, and its label is
    moved to the child if the child is unlabeled. A warning is issued when a removed node carried a
    label or ``node_params``/``edge_params``.
    '''
    pending = deque([self.root])
    while pending:
        node = pending.popleft()
        if len(node.children) != 1:
            pending.extend(node.children)
            continue
        labeled = node.label is not None and node.label != ''
        if labeled or hasattr(node, 'node_params') or hasattr(node, 'edge_params'):
            desc = node.label
            if desc is None or desc == '':
                desc = '%s %s' % (repr(node), str(node.__dict__))
            warn("Deleting a node with label/attributes in suppress_unifurcations: %s" % desc)
        only_child = node.children.pop()
        if not node.is_root():
            parent = node.parent
            parent.remove_child(node)
            parent.add_child(only_child)
        else:
            self.root = only_child
            only_child.parent = None
        if node.edge_length is not None:
            # merge the two edges into one
            base = 0 if only_child.edge_length is None else only_child.edge_length
            only_child.edge_length = base + node.edge_length
        if only_child.label is None and node.label is not None:
            only_child.label = node.label
        # the child may itself be a unifurcation
        pending.append(only_child)
[docs]
def traverse_inorder(self, leaves=True, internal=True):
    '''Yield the ``Node`` objects of this ``Tree`` in inorder (delegates to the root ``Node``)

    Args:
        ``leaves`` (``bool``): ``True`` to include leaves, otherwise ``False``
        ``internal`` (``bool``): ``True`` to include internal nodes, otherwise ``False``
    '''
    for node in self.root.traverse_inorder(leaves=leaves, internal=internal):
        yield node
[docs]
def traverse_internal(self):
    '''Yield the internal ``Node`` objects of this ``Tree`` (delegates to the root ``Node``)'''
    for node in self.root.traverse_internal():
        yield node
[docs]
def traverse_leaves(self):
    '''Yield the leaf ``Node`` objects of this ``Tree`` (delegates to the root ``Node``)'''
    for leaf in self.root.traverse_leaves():
        yield leaf
[docs]
def traverse_levelorder(self, leaves=True, internal=True):
    '''Yield the ``Node`` objects of this ``Tree`` in levelorder (delegates to the root ``Node``)

    Args:
        ``leaves`` (``bool``): ``True`` to include leaves, otherwise ``False``
        ``internal`` (``bool``): ``True`` to include internal nodes, otherwise ``False``
    '''
    for node in self.root.traverse_levelorder(leaves=leaves, internal=internal):
        yield node
[docs]
def traverse_postorder(self, leaves=True, internal=True):
    '''Yield the ``Node`` objects of this ``Tree`` in postorder (delegates to the root ``Node``)

    Args:
        ``leaves`` (``bool``): ``True`` to include leaves, otherwise ``False``
        ``internal`` (``bool``): ``True`` to include internal nodes, otherwise ``False``
    '''
    for node in self.root.traverse_postorder(leaves=leaves, internal=internal):
        yield node
[docs]
def traverse_preorder(self, leaves=True, internal=True):
    '''Yield the ``Node`` objects of this ``Tree`` in preorder (delegates to the root ``Node``)

    Args:
        ``leaves`` (``bool``): ``True`` to include leaves, otherwise ``False``
        ``internal`` (``bool``): ``True`` to include internal nodes, otherwise ``False``
    '''
    for node in self.root.traverse_preorder(leaves=leaves, internal=internal):
        yield node
[docs]
def traverse_rootdistorder(self, ascending=True, leaves=True, internal=True):
    '''Yield the ``Node`` objects of this ``Tree`` ordered by distance from the root (delegates to the root ``Node``)

    Args:
        ``ascending`` (``bool``): ``True`` to perform traversal in ascending distance from the root, otherwise ``False`` for descending
        ``leaves`` (``bool``): ``True`` to include leaves, otherwise ``False``
        ``internal`` (``bool``): ``True`` to include internal nodes, otherwise ``False``
    '''
    for item in self.root.traverse_rootdistorder(ascending=ascending, leaves=leaves, internal=internal):
        yield item
[docs]
def treeness(self):
    '''Compute the `treeness` (sum of internal branch lengths / sum of all branch lengths) of this ``Tree``. Branch lengths of ``None`` are considered 0 length

    Returns:
        ``float``: `Treeness` of this ``Tree`` (sum of internal branch lengths / sum of all branch lengths)
    '''
    internal_total = 0.
    overall_total = 0.
    for node in self.traverse_preorder():
        length = node.edge_length
        if length is None:
            continue
        overall_total += length
        if not node.is_leaf():
            internal_total += length
    return internal_total/overall_total
[docs]
def write_tree_nexus(self, filename):
    '''Write this ``Tree`` to a Nexus file

    Args:
        ``filename`` (``str``): Path to desired output file (plain-text or gzipped)

    Raises:
        ``TypeError``: If ``filename`` is not a ``str`` (or a leaf label is not a ``str``)
    '''
    if not isinstance(filename, str):
        raise TypeError("filename must be a str")
    treestr = self.newick()
    leaf_labels = [node.label for node in self.traverse_leaves()]
    if treestr.startswith('[&R]'):
        treestr = treestr[4:].strip()
    # build the whole document before touching the file, so a bad label cannot leave a truncated file
    content = ''.join([
        '#NEXUS\n',
        'Begin Taxa;\n',
        ' Dimensions NTAX=%d;\n' % len(leaf_labels),
        ' TaxLabels %s;\n' % ' '.join(leaf_labels),
        'End;\n',
        'Begin Trees;\n',
        ' Tree tree1=%s\n' % treestr,
        'End;\n',
    ])
    if filename.lower().endswith('.gz'): # gzipped file
        f = gopen(expanduser(filename), 'wt', 9)
    else: # plain-text file
        f = open(expanduser(filename), 'w')
    # `with` guarantees the file is closed (and a gzip trailer flushed)
    with f:
        f.write(content)
[docs]
def write_tree_newick(self, filename, hide_rooted_prefix=False):
    '''Write this ``Tree`` to a Newick file

    Args:
        ``filename`` (``str``): Path to desired output file (plain-text or gzipped)
        ``hide_rooted_prefix`` (``bool``): Hide the rooted prefix ``[&R]`` if rooted tree

    Raises:
        ``TypeError``: If ``filename`` is not a ``str``
    '''
    if not isinstance(filename, str):
        raise TypeError("filename must be a str")
    treestr = self.newick()
    if hide_rooted_prefix:
        if treestr.startswith('[&R]'):
            treestr = treestr[4:].strip()
        else:
            warn("Specified hide_rooted_prefix, but tree was not rooted")
    if filename.lower().endswith('.gz'): # gzipped file
        f = gopen(expanduser(filename), 'wt', 9)
    else: # plain-text file
        f = open(expanduser(filename), 'w')
    # `with` closes the file even if the write fails
    with f:
        f.write(treestr)
def plot_ltt(lineages, show_plot=True, export_filename=None, color='#000000', xmin=None, xmax=None, ymin=None, ymax=None, title=None, xlabel=None, ylabel=None):
    '''Plot the Lineages Through Time (LTT) curve of a given tree

    Args:
        ``lineages`` (``dict``): The ``lineages`` dictionary returned by a ``Tree`` object's ``lineages_through_time()`` function call
        ``show_plot`` (``bool``): ``True`` to show the plot, otherwise ``False`` to only return the dictionary. To plot multiple LTTs on the same figure, set ``show_plot`` to False for all but the last plot.
        ``export_filename`` (``str``): File to which the LTT figure will be exported (otherwise ``None`` to not save to file)
        ``color`` (``str``): The color of the resulting plot
        ``title`` (``str``): The title of the resulting plot
        ``xmin`` (``float``): The minimum value of the horizontal axis in the resulting plot
        ``xmax`` (``float``): The maximum value of the horizontal axis in the resulting plot
        ``xlabel`` (``str``): The label of the horizontal axis in the resulting plot
        ``ymin`` (``float``): The minimum value of the vertical axis in the resulting plot
        ``ymax`` (``float``): The maximum value of the vertical axis in the resulting plot
        ``ylabel`` (``str``): The label of the vertical axis in the resulting plot
    '''
    import matplotlib.pyplot as plt; from matplotlib.ticker import MaxNLocator
    # A single module-level figure is reused across calls so several LTTs can be overlaid;
    # it is reset to None once the plot has been shown and/or exported.
    global TREESWIFT_FIGURE
    if globals().get('TREESWIFT_FIGURE') is None:
        TREESWIFT_FIGURE = plt.figure()
        TREESWIFT_FIGURE.gca().yaxis.set_major_locator(MaxNLocator(integer=True)) # integer y ticks
        TREESWIFT_FIGURE.XMIN = float('inf'); TREESWIFT_FIGURE.XMAX = float('-inf')
        TREESWIFT_FIGURE.YMIN = float('inf'); TREESWIFT_FIGURE.YMAX = float('-inf')
    fig = TREESWIFT_FIGURE; ax = fig.gca()
    times = sorted(lineages.keys())
    counts = [lineages[t] for t in times]
    # track the data range over every point (including the last one, which previously never
    # updated YMAX, so a single-point LTT crashed on ceil(-inf) and the final count could be clipped)
    fig.XMIN = min(fig.XMIN, times[0]); fig.XMAX = max(fig.XMAX, times[-1])
    fig.YMIN = min(fig.YMIN, min(counts)); fig.YMAX = max(fig.YMAX, max(counts))
    # step curve: a vertical jump at each time, then a horizontal run to the next time
    for i, x in enumerate(times):
        prev = 0 if i == 0 else counts[i-1]
        ax.plot([x, x], [prev, counts[i]], color=color)
        if i + 1 < len(times):
            ax.plot([x, times[i+1]], [counts[i], counts[i]], color=color)
    if show_plot or export_filename is not None:
        if xmin is None:
            xmin = fig.XMIN
        elif not isinstance(xmin,int) and not isinstance(xmin,float):
            warn("xmin is invalid, so using the default"); xmin = fig.XMIN
        if xmax is None:
            xmax = fig.XMAX
        elif not isinstance(xmax,int) and not isinstance(xmax,float):
            warn("xmax is invalid, so using the default"); xmax = fig.XMAX
        ax.set_xlim(left=xmin, right=xmax)
        if ymin is None:
            ymin = fig.YMIN
        elif not isinstance(ymin,int) and not isinstance(ymin,float):
            warn("ymin is invalid, so using the default"); ymin = fig.YMIN
        if ymax is None:
            ymax = ceil(fig.YMAX*1.1)
        elif not isinstance(ymax,int) and not isinstance(ymax,float):
            warn("ymax is invalid, so using the default"); ymax = ceil(fig.YMAX*1.1)
        ax.set_ylim(bottom=ymin, top=ymax)
        if title is not None and not isinstance(title,str):
            warn("title is invalid, so using the default"); title = None
        ax.set_title("Lineages Through Time" if title is None else title)
        if xlabel is not None and not isinstance(xlabel,str):
            warn("xlabel is invalid, so using the default"); xlabel = None
        ax.set_xlabel("Time" if xlabel is None else xlabel)
        if ylabel is not None and not isinstance(ylabel,str):
            warn("ylabel is invalid, so using the default"); ylabel = None
        ax.set_ylabel("Number of Lineages" if ylabel is None else ylabel)
        # export before showing: closing an interactive window can leave nothing to save
        if export_filename is not None:
            fig.savefig(export_filename)
        if show_plot:
            plt.show()
        TREESWIFT_FIGURE = None
[docs]
def read_tree_dendropy(tree):
    '''Convert a DendroPy tree into a TreeSwift ``Tree``

    Args:
        ``tree`` (``dendropy.datamodel.treemodel``): A Dendropy ``Tree`` object

    Returns:
        ``Tree``: A TreeSwift tree created from ``tree``

    Raises:
        ``TypeError``: If ``tree`` does not look like a DendroPy ``Tree``
    '''
    out = Tree()
    converted = {}
    required = ('preorder_node_iter', 'seed_node', 'is_rooted')
    if not all(hasattr(tree, attr) for attr in required):
        raise TypeError("tree must be a DendroPy Tree object")
    # anything other than exactly True (e.g. an undetermined rooting) is treated as unrooted
    if tree.is_rooted != True:
        out.is_rooted = False
    for dnode in tree.preorder_node_iter():
        if dnode == tree.seed_node:
            tnode = out.root
        else:
            # preorder guarantees the parent has already been converted
            tnode = Node()
            converted[dnode.parent_node].add_child(tnode)
        converted[dnode] = tnode
        tnode.edge_length = dnode.edge_length
        taxon = getattr(dnode, 'taxon', None)
        tnode.label = taxon.label if taxon is not None else dnode.label
    return out
[docs]
def read_tree_newick(newick):
    '''Read a tree from a Newick string or file

    Args:
        ``newick`` (``str``): Either a Newick string or the path to a Newick file (plain-text or gzipped)

    Returns:
        ``Tree``: The tree represented by ``newick``. If the Newick file has multiple trees (one per line), a ``list`` of ``Tree`` objects will be returned

    Raises:
        ``TypeError``: If ``newick`` cannot be converted to a ``str``
        ``RuntimeError``: If the string cannot be parsed as Newick (the underlying error is chained as the cause)
    '''
    if not isinstance(newick, str):
        try:
            newick = str(newick)
        except Exception as e:
            raise TypeError("newick must be a str") from e
    # short strings may be file paths; long strings are assumed to be Newick text
    if len(newick) < 1000 and isfile(expanduser(newick)):
        if newick.lower().endswith('.gz'): # gzipped file
            with gopen(expanduser(newick), 'rt') as f:
                ts = f.read().strip()
        else: # plain-text file
            with open(expanduser(newick)) as f:
                ts = f.read().strip()
    else: # string
        ts = newick.strip()
    lines = [l.strip() for l in ts.splitlines()]
    if len(lines) != 1:
        # every line ending in ';' means one tree per line; otherwise one tree wrapped over lines
        if all(l.endswith(';') for l in lines):
            return [read_tree_newick(l) for l in lines]
        ts = ts.replace('\n','')
    try:
        t = Tree(); t.is_rooted = ts.startswith('[&R]')
        if ts[0] == '[':
            ts = ']'.join(ts.split(']')[1:]).strip(); ts = ts.replace(', ',',')
        n = t.root; i = 0; parse_length = False; parse_label = False
        while i < len(ts):
            # end of Newick string
            if not parse_label and ts[i] == ';':
                if i != len(ts)-1 or n != t.root:
                    raise RuntimeError(INVALID_NEWICK)
            # go to new child
            elif not parse_label and ts[i] == '(':
                c = Node(); n.add_child(c); n = c
            # go to parent
            elif not parse_label and ts[i] == ')':
                n = n.parent
            # go to new sibling
            elif not parse_label and ts[i] == ',':
                n = n.parent; c = Node(); n.add_child(c); n = c
                while ts[i+1] == ' ':
                    i += 1 # skip spaces after commas
            # comment (square brackets)
            elif not parse_label and ts[i] == '[':
                count = 0; start_ind = i
                while True:
                    if ts[i] == '[':
                        count += 1
                    elif ts[i] == ']':
                        count -= 1
                    if count == 0:
                        break
                    i += 1
                # store comment as node_params or edge_params
                curr_comment = ts[start_ind+1 : i] # don't include first and last [ and ]
                if parse_length:
                    n.edge_params = curr_comment
                else:
                    n.node_params = curr_comment
            # edge length
            elif not parse_label and ts[i] == ':':
                parse_length = True
            elif parse_length:
                ls = ''
                while ts[i] not in {',', ')', ';', '['}:
                    ls += ts[i]; i += 1
                n.edge_length = float(ls); i -= 1; parse_length = False
            # node label
            elif not parse_label and ts[i] == "'":
                parse_label = True
            else:
                label = ''
                while parse_label or ts[i] not in {':', ',', ';', ')', '['}:
                    if ts[i] == "'":
                        parse_label = not parse_label
                    else:
                        label += ts[i]
                    i += 1
                i -= 1; n.label = label; parse_label = False
            i += 1
    except Exception as e:
        # keep the original error as the cause so parse failures are debuggable
        raise RuntimeError(f"Failed to parse string as Newick: {ts}") from e
    return t
[docs]
def _read_tree_nexml_lines(lines):
    '''Parse NeXML text, given as an iterable of lines, into a ``dict`` mapping tree IDs to ``Tree`` objects'''
    trees = {}; id_to_node = {}; tree_id = None
    for line in lines:
        l = line.strip(); l_lower = l.lower()
        # start of tree
        if l_lower.startswith('<tree '):
            if tree_id is not None:
                raise ValueError(INVALID_NEXML)
            parts = l.split()
            for part in parts:
                if '=' in part:
                    k,v = part.split('='); k = k.strip()
                    if k.lower() == 'id':
                        tree_id = v.split('"')[1]; break
            if tree_id is None:
                raise ValueError(INVALID_NEXML)
            trees[tree_id] = Tree(); trees[tree_id].root = None
        # end of tree
        elif l_lower.replace(' ','').startswith('</tree>'):
            if tree_id is None:
                raise ValueError(INVALID_NEXML)
            id_to_node = {}; tree_id = None
        # node: scan attributes character by character so quoted values may contain spaces
        elif l_lower.startswith('<node '):
            if tree_id is None:
                raise ValueError(INVALID_NEXML)
            node_id = None; node_label = None; is_root = False
            k = ''; v = ''; in_key = True; in_quote = False
            for i in range(6, len(l)):
                if l[i] in {'"', "'"}:
                    in_quote = not in_quote
                if not in_quote and in_key and l[i] == '=':
                    in_key = False
                elif not in_quote and not in_key and l[i] in {'"', "'"}:
                    # closing quote: the key/value pair is complete
                    k = k.strip()
                    if k.lower() == 'id':
                        node_id = v
                    elif k.lower() == 'label':
                        node_label = v
                    elif k.lower() == 'root' and v.strip().lower() == 'true':
                        is_root = True
                    in_key = True; k = ''; v = ''
                elif in_key and l[i] not in {'"', "'"}:
                    k += l[i]
                elif not in_key and l[i] not in {'"', "'"}:
                    v += l[i]
            if node_id is None or node_id in id_to_node:
                raise ValueError(INVALID_NEXML)
            id_to_node[node_id] = Node(label=node_label)
            if is_root:
                if trees[tree_id].root is not None:
                    raise ValueError(INVALID_NEXML)
                trees[tree_id].root = id_to_node[node_id]
        # edge
        elif l_lower.startswith('<edge '):
            if tree_id is None:
                raise ValueError(INVALID_NEXML)
            source = None; target = None; length = None
            parts = l.split()
            for part in parts:
                if '=' in part:
                    k,v = part.split('='); k = k.strip(); k_lower = k.lower()
                    if k_lower == 'source':
                        source = v.split('"')[1]
                    elif k_lower == 'target':
                        target = v.split('"')[1]
                    elif k_lower == 'length':
                        length = float(v.split('"')[1])
            if source is None or target is None or length is None:
                raise ValueError(INVALID_NEXML)
            if source not in id_to_node:
                raise ValueError(INVALID_NEXML)
            if target not in id_to_node:
                raise ValueError(INVALID_NEXML)
            id_to_node[source].add_child(id_to_node[target])
            id_to_node[target].edge_length = length
        # edge above the root
        elif l_lower.startswith('<rootedge '):
            if tree_id is None:
                raise ValueError(INVALID_NEXML)
            root_node = None; length = None
            parts = l.split()
            for part in parts:
                if '=' in part:
                    k,v = part.split('='); k = k.strip(); k_lower = k.lower()
                    if k_lower == 'target':
                        target_id = v.split('"')[1]
                        # unknown targets are malformed input, like every other check here
                        if target_id not in id_to_node:
                            raise ValueError(INVALID_NEXML)
                        root_node = id_to_node[target_id]
                    elif k_lower == 'length':
                        length = float(v.split('"')[1])
            if trees[tree_id].root is None:
                raise ValueError(INVALID_NEXML)
            if root_node is not None and trees[tree_id].root != root_node:
                raise ValueError(INVALID_NEXML)
            trees[tree_id].root.edge_length = length
    return trees

def read_tree_nexml(nexml):
    '''Read a tree from a NeXML string or file

    Args:
        ``nexml`` (``str``): Either a NeXML string or the path to a NeXML file (plain-text or gzipped)

    Returns:
        ``dict`` of ``Tree``: A dictionary of the trees represented by ``nexml``, where keys are tree names (``str``) and values are ``Tree`` objects

    Raises:
        ``TypeError``: If ``nexml`` is not a ``str``
        ``ValueError``: If the NeXML is malformed
    '''
    if not isinstance(nexml, str):
        raise TypeError("nexml must be a str")
    if nexml.lower().endswith('.gz'): # gzipped file
        f = gopen(expanduser(nexml), 'rt')
    elif isfile(expanduser(nexml)): # plain-text file
        f = open(expanduser(nexml))
    else:
        f = nexml.splitlines()
    # close the file even when parsing raises
    try:
        return _read_tree_nexml_lines(f)
    finally:
        if hasattr(f, 'close'):
            f.close()
[docs]
def read_tree_nexus(nexus, translate=True):
    '''Read a tree from a Nexus string or file

    Args:
        ``nexus`` (``str``): Either a Nexus string or the path to a Nexus file (plain-text or gzipped)
        ``translate`` (``bool``): Translate the node labels on the trees (if the Nexus file has a "Translate" section)

    Returns:
        ``dict`` of ``Tree``: A dictionary of the trees represented by ``nexus``, where keys are tree names (``str``) and values are ``Tree`` objects

        If the Nexus file had a "Taxlabels" section, the taxon labels will be stored in the output dictionary as a ``list`` associated with key ``"taxlabels"``

        If the Nexus file had a "Translate" section, the translation table will be stored in the output dictionary as a ``dict`` associated with key ``"translate"``

        If any trees in the Nexus file had information (e.g. the first ``...`` in the ``tree STATE_0 [...] = [&R] (...);`` line), all information will be stored in the output dictionary: they will be associated with key ``"info"`` in the output dictionary, and they will be stored in a dictionary where keys are tree names (``str``)

        If ``translate`` was ``True``, if a node ``x`` was translated, ``x.label`` will be the translated label, and ``x.id`` will be the original label (the "ID")

    Raises:
        ``TypeError``: If ``nexus`` is not a ``str``
        ``ValueError``: If nothing (no trees, taxlabels, or translate table) could be read
    '''
    if not isinstance(nexus, str):
        raise TypeError("nexus must be a str")
    if nexus.lower().endswith('.gz'): # gzipped file
        f = gopen(expanduser(nexus), 'rt')
    elif isfile(expanduser(nexus)): # plain-text file
        f = open(expanduser(nexus))
    else:
        f = nexus.splitlines()
    trees = {}; taxlabels = None; tr = None; reading_taxlabels = False; reading_translate = False
    # close the file even when parsing raises
    try:
        for line in f:
            l = line.strip()
            # read line in TAXLABELS section (`TAXON_LABEL` on each line)
            if reading_taxlabels:
                if l == ';':
                    reading_taxlabels = False; trees['taxlabels'] = taxlabels
                else:
                    taxlabels.append(l)
            # read line in TRANSLATE section (`<whitespace>TREE_LABEL<whitespace>TAXON_LABEL,` on each line)
            elif reading_translate:
                if l == ';':
                    reading_translate = False; trees['translate'] = tr
                else:
                    parts = l.split(); tmp = ' '.join(parts[1:])
                    if tmp.endswith(','):
                        tmp = tmp[:-1]
                    tr[parts[0]] = tmp
            elif l.lower().startswith('tree '):
                # split at the last `=` before the first `(`, i.e. the one separating the name
                # (and optional `[...]` info) from the Newick string
                i = l.rindex('=', 0, l.index('('))
                left = l[:i].strip(); right = l[i+1:].strip()
                if '[' in left:
                    name = ' '.join(left.split('[')[0].split(' ')[1:]).strip()
                    if 'info' not in trees:
                        trees['info'] = {}
                    trees['info'][name] = left.split('[')[1].split(']')[0].strip()
                else:
                    name = ' '.join(left.split(' ')[1:])
                curr_tree = read_tree_newick(right)
                if translate and tr is not None:
                    for node in curr_tree.traverse_preorder():
                        if node.label in tr:
                            node.id = node.label; node.label = tr[node.id]
                trees[name] = curr_tree
            elif l.lower() == 'taxlabels':
                taxlabels = []; reading_taxlabels = True
            elif l.lower() == 'translate':
                tr = {}; reading_translate = True
    finally:
        if hasattr(f,'close'):
            f.close()
    if len(trees) == 0:
        raise ValueError(INVALID_NEXUS)
    return trees
def read_tree_linkage(linkage, return_list=False):
    '''Read a tree from linkage matrix as specified in SciPy documentation. Code largely copied from scipy's to_tree() function

    Args:
        ``linkage`` (``numpy.ndarray``): NumPy array representing linkage

            * https://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.linkage.html

        ``return_list`` (``bool``): ``True`` to also return the list of created ``Node`` objects, otherwise ``False``

    Returns:
        ``Tree`` representation of supplied linkage

            * If ``return_list`` is ``True``, also returns a ``list`` ``nd`` such that ``nd[i]`` corresponds to the ``Node`` with id ``i``

    Raises:
        ``TypeError``: If ``linkage`` is not a ``numpy.ndarray``
        ``ValueError``: If ``linkage`` is not a valid linkage matrix
    '''
    # check valid input
    from numpy import ndarray
    if not isinstance(linkage, ndarray):
        raise TypeError("linkage must be a np.ndarray")
    from scipy.cluster.hierarchy import is_valid_linkage
    is_valid_linkage(linkage, throw=True, name='linkage')
    # prepare: ids 0..n-1 are the original observations, n..2n-2 the merged clusters
    n = linkage.shape[0] + 1
    d = [Node(i) for i in range(n)] + [None]*(n-1)
    # process: row i merges clusters fi and fj into cluster i+n
    nd = None
    for i in range(n-1):
        fi = int(linkage[i,0]); fj = int(linkage[i,1])
        # cluster i+n is only created at the end of this iteration, so it cannot be referenced yet
        if fi >= i + n:
            raise ValueError('Corrupt matrix Z. Index to derivative cluster is used before it is formed. See row %d, column 0' % i)
        if fj >= i + n:
            raise ValueError('Corrupt matrix Z. Index to derivative cluster is used before it is formed. See row %d, column 1' % i)
        # NOTE(review): the second positional argument to Node is presumably edge_length (fixed at 1,
        # not the merge distance linkage[i,2]) -- confirm against Node's signature
        nd = Node(i+n, 1)
        nd.add_child(d[fi]); nd.add_child(d[fj])
        d[n + i] = nd
    # finalize: the last merge is the root
    out = Tree(); out.root = nd
    if return_list:
        return out, d
    else:
        return out
[docs]
def read_tree(input, schema, **kwargs):
    '''Read a tree from a string or file

    Args:
        ``input`` (``str``): Either a tree string, a path to a tree file (plain-text or gzipped), a DendroPy Tree object, or a linkage ``numpy.ndarray``
        ``schema`` (``str``): The schema of ``input`` (DendroPy, Newick, NeXML, Nexus, or linkage); case-insensitive
        ``**kwargs``: Extra keyword arguments forwarded to the schema-specific reader (e.g. ``translate`` for Nexus, ``return_list`` for linkage)

    Returns:
        * If the input is Newick, either a ``Tree`` object if ``input`` contains a single tree, or a ``list`` of ``Tree`` objects if ``input`` contains multiple trees (one per line)
        * If the input is NeXML or Nexus, a ``dict`` of trees represented by ``input``, where keys are tree names (``str``) and values are ``Tree`` objects
        * If the input is DendroPy or linkage, a ``Tree`` object (linkage with ``return_list=True`` also returns the ``list`` of nodes)

    Raises:
        ``TypeError``: If ``schema`` is not a ``str``
        ``ValueError``: If ``schema`` is not one of the supported schemas
    '''
    if not isinstance(schema, str):
        raise TypeError("schema must be a str")
    schema_to_function = {
        'dendropy': read_tree_dendropy,
        'newick': read_tree_newick,
        'nexml': read_tree_nexml,
        'nexus': read_tree_nexus,
        'linkage': read_tree_linkage
    }
    key = schema.lower()
    if key not in schema_to_function:
        raise ValueError(f"Invalid schema: {schema} (valid options: {', '.join(sorted(schema_to_function.keys()))})")
    return schema_to_function[key](input, **kwargs)