#!/usr/bin/python

import re
import math

def loadTree(filename):
#This function reads a file in Newick format and returns our simple
#dictionary-based data structure for trees.
#Uses parseTree() to interpret input string.
#Semicolons, all whitespace (including newlines) and [bracketed] comments
#are stripped before parsing.
	with open(filename,'r') as f: #'with' guarantees the file gets closed
		exp = f.read()

	exp = exp.replace(';','') #ignore trailing (or other) semi-colons
	#ignore whitespace; done before comment removal so that comments
	#spanning several lines collapse onto one line first
	exp = re.sub(r'\s+','',exp)
	#ignore bracketed clauses; non-greedy so that "a[x]:1,b[y]:2" only
	#drops the two comments instead of everything from the first '[' to the last ']'
	exp = re.sub(r'\[.*?\]','',exp)

	return parseTree(exp)


def makeLeaf(name,length):
#Build the node for a single leaf: it has no children, only a taxon
#name and the length of the branch leading to it.
	leaf = dict(left=None, right=None)
	leaf['name'] = name
	leaf['length'] = length
	return leaf


def parseTree(exp):
#This function takes a string in Newick format and parses it recursively.
#Each clause is expected to be of the general form (a:x,b:y):z
#where a and b may be subtrees in the same format.
#Returns either a leaf node (built by makeLeaf) or a dict with
#name 'internal' whose 'left'/'right' hold the parsed subtrees.
#A clause with no ':length' suffix (e.g. the root) gets length 0.

	if ',' not in exp: #if this is a leaf
		name, _, length = exp.partition(':')
		#a leaf without a ':length' part gets 0, same as internal nodes
		length = float(length) if length else 0
		return makeLeaf(name,length)

	#split "(subtree):length" into its two parts; the length may be written
	#with either exponent case and an explicit sign (e.g. 1E-05, 2.5e+3)
	distPattern = re.compile(r'(?P<tree>\(.+\)):(?P<length>[-+.\deE]+)$')
	m = distPattern.search(exp)
	length = 0
	if m:
		length = float(m.group('length'))
		exp = m.group('tree')

	#Use the parseExp function to return the left and right hand sides
	#of the expression (e.g., a & b from (a,b))
	lhs, rhs = parseExp(exp)

	#Now package into a tree data structure
	return { "name":"internal",
			 "left":parseTree(lhs), #recursively set up subtrees
			 "right":parseTree(rhs),
			 "length":length }


def _topLevelComma(s):
	#Return the index of the first comma in s that is not nested inside
	#parentheses, or -1 if there is none.
	depth = 0
	for i, c in enumerate(s):
		if c == '(':
			depth += 1
		elif c == ')':
			depth -= 1
		elif c == ',' and depth == 0:
			return i
	return -1


def parseExp(exp):
	#Parse expression of type "(a,b)" into [a, b] where a and b can be
	#Newick formatted strings. The surrounding parens are dropped and the
	#split happens at the first comma not nested inside parentheses.
	#If there is no such comma, returns [inner, ''].
	#A multifurcation "(a,b,c)" is resolved as "(a,(b,c))": the remainder is
	#re-wrapped in parens so that parseTree() parses it as a subtree
	#(with branch length 0) instead of mangling the bare "b,c".
	inner = exp[1:-1] #get rid of surrounding parens
	cut = _topLevelComma(inner)
	if cut < 0:
		return [ inner, '' ]

	left, right = inner[:cut], inner[cut+1:]
	if _topLevelComma(right) >= 0:
		right = '(' + right + ')'

	#Now return the left and right substrings
	return [ left, right ]


def readAlignment(filename):
#read an alignment in Phylip (sort of) format
#and return a dictionary mapping taxon name -> sequence string.
#The first non-blank line is the "<taxa> <columns>" header (read but not
#otherwise used). Every later non-blank line must start with the taxon name
#followed by one or more sequence chunks. Repeated names are concatenated,
#so interleaved blocks are supported. Blank lines are skipped.
	taxa = None
	columns = None
	sequences = {}

	with open(filename,"r") as f: #'with' guarantees the file gets closed
		for line in f:
			words = line.split() #lines can have whitespace
			if not words:
				continue #blank line, e.g. between interleaved blocks
			if taxa is None:
				#first line tells how many seqs and cols
				taxa, columns = words
				continue
			name = words[0] #we'll require to start with taxon name
			seq = ''.join(words[1:])
			sequences[name] = sequences.get(name,'') + seq

	return sequences


def initTree(tree,aln):
#Attach each leaf's aligned sequence as its 'data' field, stored as one
#single-character list per alignment column. Internal nodes are left
#untouched, apart from recursing into both children (left first).
#Returns the sequence length of the right-most leaf reached.
	if tree['name'] == 'internal':
		initTree(tree['left'],aln)
		return initTree(tree['right'],aln)

	seq = aln[tree['name']]
	tree['data'] = [ [ch] for ch in seq ]
	return len(seq)


def evoModel(x,y,distance):
#this function returns the probability of a change in sequence from
#x->y given an evolutionary distance
#Uses the Jukes-Cantor (JC69) model, taking distance as the expected number
#of substitutions per site (NOTE(review): confirm this is the intended unit
#for the branch lengths in the input trees):
#  P(x->x) = 1/4 + 3/4 * exp(-4d/3)
#  P(x->y) = 1/4 - 1/4 * exp(-4d/3)   for y != x
#Raises ValueError for a negative distance, which would give probabilities
#outside [0,1].
	import math #the math.exp() function may be useful
	if distance < 0:
		raise ValueError('evolutionary distance must be non-negative')
	decay = math.exp(-4.0 * distance / 3.0)
	if x == y:
		return 0.25 + 0.75 * decay
	return 0.25 - 0.25 * decay


def ml(tree,pos):
#this function returns a dictionary containing the likelihood
#of each of the characters ['A','C','G','T']
#at alignment column pos, for the subtree rooted at tree, conditioned on
#that character being the state at the subtree's root (Felsenstein pruning).
#A leaf scores 1 for its observed character and 0 for the others.
#An internal node multiplies, over both children, the sum over child states
#y of P(x->y, child branch length) * childLikelihood[y].
	bases = ['A','C','G','T']

	if tree['name'] != 'internal':
		likelihood = {}
		for n in bases:
			if [n] == tree['data'][pos]:
				likelihood[n] = 1
			else:
				likelihood[n] = 0
		return likelihood

	#compute each child's table once, not once per parent state
	children = [ (child['length'], ml(child,pos))
				 for child in (tree['left'], tree['right']) ]

	likelihood = {}
	for x in bases:
		p = 1.0
		for dist, childLik in children:
			p *= sum(evoModel(x,y,dist) * childLik[y] for y in bases)
		likelihood[x] = p
	return likelihood

def maxItem(x):
#return the dictionary entry with the max value, as a (value, key) tuple.
#On ties the first maximal key in iteration order wins.
#Raises ValueError for an empty dictionary.
	if not x:
		raise ValueError('maxItem() arg is an empty dictionary')
	bestKey = max(x, key=x.get)
	return x[bestKey], bestKey




if __name__ == '__main__':
	#Run the per-site likelihood analysis only when executed as a script,
	#not when this module is imported.
	tree = loadTree('tree5.txt')
	aln = readAlignment('seqs5.aln')
	alnLength = initTree(tree,aln)

	#print(...) with a single argument produces the same output on Python 2 and 3
	print("(Likelihood of site, Best sequence at root)")
	for pos in range(alnLength):
		print(maxItem(ml(tree,pos)))

