Skip to content
Snippets Groups Projects
Commit 6a334ff0 authored by Menardo Fabrizio's avatar Menardo Fabrizio
Browse files

add Treemmer.py

parents
Branches
No related tags found
No related merge requests found
#Treemmer
#Author: Fabrizio Menardo
from joblib import Parallel, delayed
from ete3 import Tree
import sys
import random
import operator
import argparse
############################################################ define argparse type: float with 0 <= X <= 1 ###############################################################
def restricted_float(x):
    """Convert ``x`` to a float and require it to lie in [0.0, 1.0].

    Intended as an argparse ``type=`` callable.

    Parameters:
        x: anything ``float()`` accepts (normally the command-line string).

    Returns:
        The converted float.

    Raises:
        argparse.ArgumentTypeError: if the value is outside [0.0, 1.0],
            including NaN.
        ValueError: propagated from ``float()`` for non-numeric input
            (argparse reports it as an invalid value).
    """
    x = float(x)
    # Test membership positively: NaN fails every ordering comparison, so
    # "x < 0.0 or x > 1.0" would wrongly let it through.
    if not 0.0 <= x <= 1.0:
        raise argparse.ArgumentTypeError("%r not in range [0.0, 1.0]" % (x,))
    return x
########################################## FIND LEAVES NEIGHBORS OF A LEAF (2 NODE OF DISTANCE MAX) and calc DISTANCE #######################
def find_N(t,leaf):
    """Collect the distances between ``leaf`` and the leaves located near it.

    Parameters:
        t: not used in the body.
        leaf: node whose neighbours are searched; the code reads its ``up``,
            ``dist`` and ``name`` attributes and calls ``get_distance`` on it.

    Returns:
        dict mapping ``"<leaf.name>,<other.name>"`` to the value returned by
        ``leaf.get_distance(other)``. The dict is empty when the parent of
        ``leaf`` is the root.

    The search runs in three passes:
      1. children of the parent that are leaves (other than ``leaf``) are
         recorded directly;
      2. while no such sibling leaf has been seen, leaf children of internal
         siblings are gathered, and kept only if exactly one was counted;
      3. if no sibling leaf was found, leaf children of the grandparent are
         gathered, and kept only if exactly one was counted.

    Reads the module globals ``arguments`` (``verbose``) and, for the
    ``verbose == 3`` trace, ``counter``.

    NOTE(review): the source this was recovered from had lost its leading
    whitespace; the nesting below is a reconstruction - confirm against the
    upstream file.
    """
    dlist ={}
    parent= leaf.up
    dist_parent=leaf.dist  # assigned but never read again in this function
    flag=0  # non-zero once a sibling leaf is found, or when the parent is the root
    if arguments.verbose==3:
        print "leaf findN at ieration: " + str(counter)
        print leaf
        print "parent findN at ieration: " + str(counter)
        print parent
        print parent.get_children()
    sister_flag=0  # number of leaves counted among the children of internal siblings
    # Passes 1 and 2: walk the children of the parent.
    for n in range(0,len(parent.get_children())):
        if parent.is_root():
            # Leaves hanging directly off the root get no neighbours; flag=1
            # also disables passes 2 and 3 below.
            flag=1
            break
        if arguments.verbose==3:
            print "children " + str(n)
            print parent.children[n]
        if (parent.children[n].is_leaf()): # sibling leaf: one node of distance
            if (parent.children[n] != leaf):
                DIS = leaf.get_distance(parent.children[n])
                dlist.update({leaf.name + "," +parent.children[n].name : DIS})
                flag=flag+1
                if arguments.verbose==3:
                    print leaf.name + "," +parent.children[n].name + str(DIS) + "have one node of distance"
        else:
            if flag == 0:
                if arguments.verbose==3: # sibling is an internal node: look at its leaf children (two nodes of distance)
                    print "going up, brother is node"
                # NOTE(review): temp_dlist is recreated for every internal
                # sibling while sister_flag keeps accumulating, so only the
                # last internal sibling's leaves survive - confirm intended.
                temp_dlist={}
                for nn in range(0,len(parent.children[n].get_children())):
                    if (parent.children[n].children[nn].is_leaf()):
                        DIS = leaf.get_distance(parent.children[n].children[nn])
                        temp_dlist.update({leaf.name + "," +parent.children[n].children[nn].name : DIS})
                        sister_flag=sister_flag +1
    if ((sister_flag==1) and (flag==0)): # keep the two-node candidates only if no closer (sibling) leaf exists
        dlist.update(temp_dlist)
        if arguments.verbose==3:
            print str(temp_dlist) + " are not sister taxa, but neighbours first is leaf, second is upper neighbor"
    if (flag == 0): # the leaf has no sibling leaf
        parent=parent.up # move one step towards the root and look at the grandparent's leaf children
        multi_flag=0  # number of leaf children found under the grandparent
        if arguments.verbose==3:
            print "going down"
            print "gran parent"
            print parent
        temp_dlist={}
        # Pass 3: walk the children of the grandparent.
        for n in range(0,len(parent.get_children())):
            if parent.is_root():
                # Nothing is collected when the grandparent is the root.
                break
            if (parent.children[n].is_leaf()):
                DIS = leaf.get_distance(parent.children[n])
                multi_flag = multi_flag+1
                temp_dlist.update({leaf.name + "," +parent.children[n].name : DIS})
        if multi_flag==1: # accept only a single leaf; this is to deal with polytomies
            dlist.update(temp_dlist)
            if arguments.verbose==3:
                # n and DIS still hold whatever the loop above left in them
                print leaf.name + "," +parent.children[n].name + str(DIS) + " are not sister taxa, but neighbours first is leaf, second is neighbor of downstair (towards root)"
    return dlist
########################################## IDENTIFY LEAF TO PRUNE #######################
def find_leaf_to_prune(dlist):
    """Pick the leaf to remove from the closest pair of neighbouring leaves.

    Parameters:
        dlist: non-empty dict mapping ``"<name1>,<name2>"`` to the distance
            between the two leaves.

    Returns:
        ``(leaf_name, leaf_dist)``: the name of the selected leaf and the
        ``dist`` attribute of that same leaf.

    Raises:
        ValueError: if ``dlist`` is empty (from ``min``), or if
            ``arguments.leaves_pair`` is not 0, 1 or 2 when the two leaves
            have different ``dist`` values.
        IndexError: if a name in the chosen key is not found in the tree.

    Reads the module globals ``t`` (tree searched with ``search_nodes``) and
    ``arguments`` (``leaves_pair``: 0 prunes the leaf with the larger
    ``dist``, 1 the one with the smaller ``dist``, 2 picks at random).
    """
    min_val = min(dlist.values())
    # Several pairs can share the minimum distance; choose one of them at random.
    closest = [key for key, value in dlist.items() if value == min_val]
    pair = str(random.choice(closest)).split(",")
    leaf1 = t.search_nodes(name=pair[0])[0]
    leaf2 = t.search_nodes(name=pair[1])[0]

    mode = arguments.leaves_pair
    if leaf1.dist == leaf2.dist or mode == 2:
        # Random pick inside the pair. The distance reported must belong to
        # the leaf actually picked, otherwise the caller's running tree
        # length drifts away from the real one.
        chosen = random.choice([leaf1, leaf2])
    elif mode in (0, 1):
        if leaf1.dist > leaf2.dist:
            longer, shorter = leaf1, leaf2
        else:
            longer, shorter = leaf2, leaf1
        chosen = longer if mode == 0 else shorter
    else:
        raise ValueError("leaves_pair must be 0, 1 or 2, got %r" % (mode,))
    return (chosen.name, chosen.dist)
########################################## PRUNE LEAF FROM TREE #######################
def prune_t(leaf_to_prune, tree):
    """Remove the node named ``leaf_to_prune`` from ``tree`` and return ``tree``.

    When the removed node has exactly one sibling, the parent's branch
    length is first added to that sibling, and the parent (left with a
    single child after the detach) is then deleted, so the surviving branch
    keeps its total length. With more than two children the parent is left
    in place.
    """
    target = tree.search_nodes(name=leaf_to_prune)[0]
    holder = target.up

    if len(holder.get_children()) == 2:
        # Hand the parent's branch length over to the sibling that stays.
        for node in holder.children[:2]:
            if node != target:
                node.dist = node.dist + holder.dist

    target.detach()

    # A parent left with one child is a redundant node: splice it out.
    if len(holder.get_children()) == 1:
        holder.delete()
    return tree
#################################################################### calculate Tree length ##########################################################
def calculate_TL(t):
    """Return the tree length: the sum of ``dist`` over every node yielded by ``t.traverse()``."""
    total = 0
    for node in t.traverse():
        total += node.dist
    return total
########################################## PRUNE LEAF FROM MATRIX #######################
def prune_dist_matrix(dlist, leaf_to_prune):
    """Drop every pair that involves ``leaf_to_prune`` from ``dlist``.

    Parameters:
        dlist: dict keyed by ``"<name1>,<name2>"``; modified in place.
        leaf_to_prune: name of the leaf that has been removed from the tree.

    Returns:
        The same ``dlist`` object, after the matching keys were deleted.
    """
    # Collect first, delete afterwards: a dict must not change size while it
    # is being iterated. Iterating the dict itself works on Python 2 and 3.
    doomed = [key for key in dlist if leaf_to_prune in key.split(",")]
    for key in doomed:
        del dlist[key]
    return dlist
########################################## parallel loop #######################
def parallel_loop(i):
    """Compute the neighbour distances for one worker's share of the leaves.

    Worker ``i`` handles ``leaves[i]``, ``leaves[i + cpu]``,
    ``leaves[i + 2 * cpu]`` and so on, where ``cpu`` is ``arguments.cpu``
    (expected to be >= 1).

    Parameters:
        i: index of the first leaf handled by this worker.

    Returns:
        dict merging the non-empty results of ``find_N`` for those leaves.

    Reads the module globals ``leaves``, ``t`` and ``arguments``. The result
    is built in a private dict rather than in a shared module-level one, so
    each call returns only its own entries.
    """
    found = {}
    for leaf in leaves[i::arguments.cpu]:
        pairs = find_N(t, leaf)
        if pairs:
            found.update(pairs)
    return found
########################################## write output with stop option #######################
def write_stop(t, output1, output2):
    """Write the tree and the list of its leaf names to two files.

    Parameters:
        t: tree object; ``t.write()`` supplies the text stored in
            ``output1`` and ``t.get_leaves()`` the leaves whose ``name`` is
            listed in ``output2``.
        output1: path of the file receiving the tree text.
        output2: path of the file receiving the leaf names, one per line
            (no trailing newline).
    """
    # "with" closes each file even if serialising or writing fails.
    with open(output1, "w") as tree_file:
        tree_file.write(t.write())
    names = [leaf.name for leaf in t.get_leaves()]
    with open(output2, "w") as list_file:
        list_file.write("\n".join(names))
###### SOFTWARE START
# Command-line interface. The parsed namespace is stored in the module
# global "arguments", which the functions of this script read directly.
parser = argparse.ArgumentParser()
parser.add_argument('INFILE', type=str, help='path to the newick tree')
parser.add_argument(
    '-r', '--resolution', metavar='INT', default=1,
    help='number of leaves top prune at each iteration (default: 1)',
    type=int, nargs='?')
parser.add_argument(
    '-c', '--cpu', metavar='INT', default=1,
    help='number of cpu to use (default: 1)',
    type=int, nargs='?')
# The help text used to advertise "default: 1" while the default is 0.
parser.add_argument(
    '-v', '--verbose', metavar='0,1,2', default='0',
    help='0: silent, 1: show progress, 2: print tree at each iteration, 3: only for testing (findN), 4: only for testing (prune_t) (default: 0)',
    type=int, nargs='?', choices=[0, 1, 2, 3, 4])
parser.add_argument(
    '-p', '--solve_polytomies',
    help='resolve polytmies at random (default: FALSE)',
    action='store_true', default=False)
# Stop criteria: at most one of -X / -RTL may be given (checked below).
parser.add_argument(
    '-X', '--stop_at_X_leaves', metavar='0-n_leaves', default='0',
    help='stop pruning when the number of leaves fall below X (integer)',
    type=int, nargs='?')
parser.add_argument(
    '-RTL', '--stop_at_RTL', metavar='0-1', default='0',
    help='stop pruning when the relative tree length falls below RTL (decimal number between 0 and 1)',
    type=restricted_float, nargs='?')
# choices= rejects values other than 0/1/2 at parse time instead of letting
# them fail later, deep inside the pruning loop.
parser.add_argument(
    '-lp', '--leaves_pair', metavar='0,1,2', default=2,
    help='After the pair of leaves with the smallest distance is dentified Treemmer prunes: 0: the longest leaf\n1: the shortest leaf\n2: random choice (default)',
    type=int, nargs='?', choices=[0, 1, 2])
arguments = parser.parse_args()
if arguments.stop_at_RTL > 0 and arguments.stop_at_X_leaves > 0:
    # parser.error prints the usage line and exits with status 2, which is
    # the normal argparse way to report an invalid combination of options.
    parser.error("-X and -RTL are mutually exclusive arguments")
# Load the tree and initialise the state shared by the pruning loop.
t = Tree(arguments.INFILE, format=1)
counter = 0      # iteration number
output = []      # one "<relative tree length> <number of leaves>" row per iteration
stop = 0         # set to 1 once a -X / -RTL stop criterion has been met
TOT_TL = calculate_TL(t)   # length of the tree as loaded
TL = TOT_TL                # running tree length, reduced as leaves are pruned
if arguments.solve_polytomies:
    t.resolve_polytomy()

# Start-up banner.
# NOTE(review): the nesting of the banner lines under "verbose > 0" is
# reconstructed (the recovered source had lost its indentation) - confirm.
if arguments.verbose > 0:
    print("N of taxa in tree is : " + str(len(t)))
    if arguments.solve_polytomies:
        print("\nPolytomies will be solved at random")
    else:
        print("\nPolytomies will be kept")
    if arguments.stop_at_X_leaves:
        # A space after "to" keeps the number from being glued to the word.
        print("\nTreemmer will reduce the tree to " + str(arguments.stop_at_X_leaves) + " leaves")
    elif arguments.stop_at_RTL:
        print("\nTreemmer will reduce the tree to " + str(arguments.stop_at_RTL) + " of the original tree length")
    else:
        print("\nTreemmer will calculate the tree length decay")
    print("\nTreemmer will prune " + str(arguments.resolution) + " leaves at each iteration")
    print("\nTreemmer will use " + str(arguments.cpu) + " cpu(s)")
# Main pruning loop: at every iteration rebuild the table of neighbouring
# leaf pairs, then prune up to "resolution" leaves from it.
# NOTE(review): the nesting is reconstructed (the recovered source had lost
# its indentation) - confirm against the upstream file.
while len(t) > 3:
    counter = counter + 1
    leaves = t.get_leaves()
    DLIST = {}
    if arguments.verbose > 0:
        print("\niter " + str(counter))
    if arguments.verbose > 1:
        print("calculationg distances")
    # One task per cpu; each task returns a dict of neighbour distances.
    per_worker = Parallel(n_jobs=arguments.cpu)(delayed(parallel_loop)(i) for i in range(0, arguments.cpu))
    DLIST = {}
    for d in per_worker:  # merge the per-task dicts into a single table
        DLIST.update(d)
    if arguments.verbose > 1:
        print(DLIST)
        print("\npruning big deal\n")
    if not DLIST:
        # No neighbouring pair means nothing can be pruned, so the tree
        # would never shrink and this loop would never end.
        sys.stderr.write("no neighbouring leaves found at iteration " + str(counter) + ", pruning stopped\n")
        break
    for r in range(1, arguments.resolution + 1):
        if len(DLIST) == 0:
            break
        (leaf_to_p, dist) = find_leaf_to_prune(DLIST)
        t = prune_t(leaf_to_p, t)
        TL = TL - dist
        DLIST = prune_dist_matrix(DLIST, leaf_to_p)
        rel_TL = TL / TOT_TL
        if arguments.stop_at_X_leaves:
            if arguments.stop_at_X_leaves >= len(t):
                output1 = arguments.INFILE + "_trimmed_tree_X_" + str(arguments.stop_at_X_leaves)
                output2 = arguments.INFILE + "_trimmed_list_X_" + str(arguments.stop_at_X_leaves)
                write_stop(t, output1, output2)
                stop = 1
                break
        if arguments.stop_at_RTL:
            if arguments.stop_at_RTL >= rel_TL:
                output1 = arguments.INFILE + "_trimmed_tree_RTL_" + str(arguments.stop_at_RTL)
                output2 = arguments.INFILE + "_trimmed_list_RTL_" + str(arguments.stop_at_RTL)
                write_stop(t, output1, output2)
                stop = 1
                break
        if arguments.verbose > 1:
            print("\n ITERATION RESOLUTION: " + str(r))
            print("leaf to prune:\n" + str(leaf_to_p) + " " + str(dist))
            print("\n new tree")
            print(t)
            print("\nRTL : " + str(rel_TL) + " N_seq: " + str(len(t)))
            print("\nnew matrix\n")
            print(DLIST)
    if stop == 1:
        break
    output.append(str(rel_TL) + ' ' + str(len(t)))
    if arguments.verbose == 1:
        print("\nRTL : " + str(rel_TL) + " N_seq: " + str(len(t)))

# Without a stop criterion being met, write the tree length decay table.
if stop == 0:
    with open(arguments.INFILE + "_res_" + str(arguments.resolution) + "_LD", "w") as F:
        F.write("\n".join(output))
########################################### LOOP THROUGH ALL LEAVES (replaced by the joblib Parallel line) #############################################
# for leaf in leaves:
# counter1 = counter1 + 1
# print "iter find N : "+ str(counter1)
# N_list=find_N(t,leaf)
# print N_list
# if N_list:
# DLIST.update(N_list)
########################################### CALCULATE MATRIX DISTANCE into simple DICTIONARY very slow!! dendropy is orders of magnitude faster #############################################
#def make_dist_matrix(leaves):
# count=0
# dlist = {}
# for x in range (0,len(leaves)-1):
# count=count+1
#
# for y in range (x+1,len(leaves)):
#
# DIS = leaves[x].get_distance(leaves[y])
# dlist.update({leaves[x].name + "," +leaves[y].name : DIS})
#
# return dlist
#print dlist
########################################### CALCULATE MATRIX DISTANCE into nested DICTIONARY ##############################################
#m = {}
#for x in range (0,len(leaves)-2):
# d = {}
# for y in range (x+1,len(leaves)-1):
#
# DIS = leaves[x].get_distance(leaves[y])
# #print str(leaves[x]) + str(leaves[y]) + "\n" + str(DIS)
# d.update({leaves[y].name : DIS})
#print "D"#
#print d
# m.update({leaves[x].name : d})
#print "M"
#print m
#print m
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment