#!/usr/bin/python

import re
import math

def loadTree(filename):
	# Read a file in Newick format and return our simple dictionary-based
	# data structure for trees.
	# filename: path of the text file holding the Newick expression.
	# Returns whatever parseTree() builds from the cleaned-up string.
	# Raises IOError/OSError if the file cannot be opened.
	with open(filename, 'r') as f: #the handle is closed even if read() fails
		exp = f.read()

	exp = exp.replace(';', '') #ignore trailing (or other) semi-colons
	exp = re.sub(r'\s+', '', exp) #ignore all whitespace, newlines included
	# Drop each bracketed clause on its own. A greedy '.*' would also
	# swallow the tree text lying between the first '[' and the last ']'.
	exp = re.sub(r'\[[^\]]*\]', '', exp)

	return parseTree(exp)


def makeLeaf(name,length):
	# Build the tree structure for a single leaf: a node that has no
	# children and carries the given taxon name and branch length.
	leaf = dict(left=None, right=None)
	leaf['name'] = name
	leaf['length'] = length
	return leaf


def parseTree(exp):
	# Recursively parse a string in Newick format into nested dictionaries.
	# Each clause is expected to be of the general form (a:x,b:y):z
	# where a and b may be subtrees in the same format.
	# Returns a leaf built by makeLeaf() when exp holds no comma, otherwise
	# an internal node with the keys 'name' (always "internal"), 'left',
	# 'right' and 'length'.
	# Raises ValueError if a branch length cannot be converted to float.

	if ',' not in exp: #if this is a leaf
		# partition() tolerates a missing ":length" suffix; the branch
		# length then defaults to zero, as it does for internal nodes.
		name, _, length = exp.partition(':')
		return makeLeaf(name, float(length) if length else 0.0)

	# Split a trailing ":length" off a parenthesised clause. The character
	# class also accepts 'E' and '+' so exponents such as 1E+2 are
	# recognised instead of being left glued to the clause.
	distPattern = re.compile(r'(?P<tree>\(.+\))\:(?P<length>[eE+\-\d\.]+)$')
	m = distPattern.search(exp)
	length = 0 #clauses without an explicit length (e.g. the root) get 0
	if m:
		length = float(m.group('length'))
		exp = m.group('tree')

	#Use the parseExp function to return the left and right hand sides
	#of the expression (e.g., a & b from (a,b))
	lhs, rhs = parseExp(exp)

	#Now package into a tree data structure
	return { "name":"internal",
			 "left":parseTree(lhs), #recursively set up subtrees
			 "right":parseTree(rhs),
			 "length":length }


def parseExp(exp):
	#Split an expression of type "(a,b)" into its two halves a & b, where
	#a and b can themselves be Newick formatted strings.
	#Returns the list [left, right]; right is '' when no top-level comma
	#is found.
	inner = exp[1:-1] #get rid of surrounding parens
	depth = 0 #how many parens are currently open

	#only a comma seen outside of any nested parens separates the halves
	for pos, ch in enumerate(inner):
		if ch == '(':
			depth += 1
		elif ch == ')':
			depth -= 1
		elif ch == ',' and depth == 0:
			#everything after the first top-level comma is the right side
			return [ inner[:pos], inner[pos + 1:] ]

	return [ inner, '' ]


def readAlignment(filename):
	# Read an alignment in Phylip (sort of) format and return a dictionary
	# mapping each taxon name to its sequence string.
	# The first line tells how many seqs and cols there are; it is skipped
	# and the counts are not checked against the data.
	# Every following non-blank line must start with the taxon name; the
	# remaining whitespace-separated words are joined into the sequence.
	# Lines repeating a name are appended to that taxon in file order.
	# Raises IOError/OSError if the file cannot be opened.
	sequences = {}

	with open(filename, "r") as f: #the handle is closed when we are done
		next(f, None) #skip the header line; an empty file yields {}
		for line in f:
			words = line.split() #lines can have whitespace
			if not words:
				# blank separator lines carry no taxon name
				continue
			name = words[0]
			sequences[name] = sequences.get(name, '') + ''.join(words[1:])

	return sequences


def initTree(tree,aln):
	# Copy the sequences of an alignment into the 'data' field of the tree.
	# Each leaf receives a list holding one single-character list per
	# position of its sequence, looked up in aln by the leaf's name.
	# Internal nodes are only traversed, never modified.
	if tree['name'] == 'internal':
		for side in ('left', 'right'):
			initTree(tree[side], aln)
	else:
		residues = aln[tree['name']]
		tree['data'] = [ [residue] for residue in residues ]

def downPass(tree):
	# Return the number of mutations necessary to explain the sequence
	# data, given the tree topology, and assign sequences to the internal
	# nodes along the way.
	# tree: a node dictionary whose leaves already carry a 'data' list
	# (one list of characters per alignment column).
	# Side effect: every internal node gets its own 'data' list.

	# stop condition: a leaf costs nothing.
	# Compare by value; an identity test ('is not') on a string only
	# works when the interpreter happens to share the string object.
	if tree['name'] != "internal":
		return 0

	leftCost = downPass(tree['left'])
	rightCost = downPass(tree['right'])

	# the children's sequences
	leftSeq = tree['left']['data']
	rightSeq = tree['right']['data']

	# initiate current node's sequences
	tree['data'] = []
	mutations = 0

	# iterate over all columns of the sequences
	for i, leftStates in enumerate(leftSeq):
		rightStates = rightSeq[i]

		# if the intersection is not empty, keep it; else, count a
		# mutation and keep the union.
		common = Intersect(leftStates, rightStates)
		if common:
			tree['data'].append(common)
		else:
			tree['data'].append(Union(leftStates, rightStates))
			mutations += 1

	return mutations + leftCost + rightCost
	
# homebrewed intersection and union functions
def Intersect(list1,list2):
	# Return a new list with the items of list1 that also occur in list2,
	# in list1's order (repeated items of list1 are kept).
	return [ item for item in list1 if item in list2 ]

def Union(list1,list2):
	# Return a new list made of every item of list1 (as is, repeats
	# included) followed by the items of list2 not collected so far.
	merged = list(list1)
	for candidate in list2:
		if candidate in merged:
			continue
		merged.append(candidate)
	return merged

# Load the tree and the alignment from the hard-coded input files.
tree = loadTree('Tree4.txt')
aln = readAlignment('seqs.aln')

# Attach the aligned sequences to the leaves of the tree.
initTree(tree,aln)

# print(...) with a single argument runs under both Python 2 and Python 3,
# whereas the bare print statement is a syntax error on Python 3.
# downPass() must run first: it fills in tree['data'] for the root.
print('Number of mutations: %d' % (downPass(tree)))
print('Sequence at root: %s' % ','.join( [ '/'.join(x) for x in tree['data'] ] ))
