import naivebayes as nb
import time,os,json
import pretestprocessing as ptp

###################################################################################
## TEST A NAIVE BAYESIAN LEARNER (AND SAVE FALSE POSITIVES AND CLOSE RELATIVES)  ##
## This code implements testing for the Naive Bayesian learner defined in the    ##
## ...separate file naivebayes.py.                                               ## 
## Note that many of the functions used below are defined in the file            ##
## ... pretestprocessing.py, imported above                                      ##
## We also save the trial ids of correctly classified  trials as well as of      ##
## ... false positives and of trials almost classified as the target category    ##
## ... in their respective text files so we can examine the trial texts of those.##
## Note that there are no complex imports here, so this can run under pypy,      ##
## ... which cuts run time to less than half.                                    ##
###################################################################################



# Preliminaries: adjust these settings before running.
# Base directory; also used for reading in general files like trialdict.
indirname = '../baileyfiles/'
# Directory holding the trials by category.
categoriesdir = indirname + '1830s-trialsbycategory/'
# Location of the 10-fold cross-validation samples.
sampledirname = indirname + 'Samples_1830s/'
stopwordfilename = 'english-stopwords.txt'
# The two settings below should be set to None if not in use.
# When cattocheck is None the classifier does multiway classification;
# otherwise recognition of that one category is evaluated against the rest.
cattocheck = 'breakingpeace'
# Regex pattern to use if the category is not the complete filename.
# Remember that the pattern must find the kind of categories you are checking for!
pattern = '[^-]+'



# Here begins the code.

# Load the trial data once; every cross-validation run splits this same data.
print('Reading in the data...')
trialdata = ptp.process_data(
    categoriesdir,
    stopwordfilename,
    cattocheck,
    pattern,
)

# Start the clock; the elapsed time is reported at the end of the script.
start = time.time()

# Tallies for evaluating correctness of classification
# (each cross-validation run appends one entry to every list).
hitlist, guesslist, correctguesslist = [], [], []
totalslist, catinsamplelist = [], []

# Trial ids saved for later inspection: borderline cases (together with
# how close each one was), incorrectly included trials, and correctly
# classified trials.
closetrials, difflist = [], []
falsepositives, correctlyclassified = [], []


# One pass per file in the sample directory. `run` decides which sample is
# the test set and, after the loop, holds the number of runs performed.
run = 0
for fn in sorted(os.listdir(sampledirname)):

    # tallies for this run only
    hits = guesses = correctguesses = total = catinsample = 0

    # split train and test
    print('Creating train and test sets, run {0}'.format(run))
    trainsetids, testsetids = ptp.create_sets(sampledirname, run)
    traindata, testdata = ptp.splittraintest(trainsetids, testsetids, trialdata)

    # train learner
    print('Training learner, run {0}...'.format(run))
    mynb = nb.naivebayes()
    mynb.train(traindata)

    # test learner
    print('Testing learner, run {0}...'.format(run))

    for trialset in testdata:
        # the first element is the true class, the rest are its trials
        correctclass = trialset[0]
        for trial in trialset[1:]:
            result = mynb.classify(trial)
            # the guess is the class with the highest value in the result
            guessedclass = max(result, key=result.get)
            guessedright = guessedclass == correctclass

            # overall tally, kept for multi-way and two-way classification
            total += 1
            if guessedright:
                correctguesses += 1

            # everything below is the more complex evaluation used only
            # for two-way (one class against rest) classification
            if not cattocheck:
                continue

            pickedtarget = guessedclass == cattocheck

            # check how sure we were of the classification: if we decided
            # that this trial didn't belong but weren't very sure, save
            # the trial id so it can be checked out later
            diff = abs(result[cattocheck] - result['other'])
            if not pickedtarget and diff < 10:
                closetrials.append(trial[0])
                difflist.append(diff)

            # then record correctness of the classification result
            if correctclass == cattocheck:
                catinsample += 1
            if pickedtarget:
                guesses += 1
                if guessedright:
                    hits += 1
                    correctlyclassified.append(trial[0])
                else:
                    falsepositives.append(trial[0])

    # bank this run's tallies
    hitlist.append(hits)
    guesslist.append(guesses)
    correctguesslist.append(correctguesses)
    totalslist.append(total)
    catinsamplelist.append(catinsample)
    run += 1

# Save close trials and false positives for closer examination
# note that if there are multiple offenses in a trial, we only save the first one listed
def _savetriallines(filename, rows):
    """Write rows to filename as text, one 'field, field, ...' line per row.

    rows is an iterable of sequences of strings. The whole text is built
    with a single join and written in one call; an empty rows still
    creates (or truncates) the file.
    """
    text = ''.join(', '.join(row) + '\n' for row in rows)
    with open(filename, 'w') as outfile:
        outfile.write(text)


if cattocheck:
    print('Saving correctly classified trials and close matches...')

    trialdict_fn = indirname + 'trialdict.json'
    corrects_fn = indirname + 'cclassedtrials.txt'
    closetrials_fn = indirname + 'closetrials.txt'
    falseps_fn = indirname + 'falsepositives.txt'

    # trialdict is indexed by trial id and yields the category written
    # next to each id below
    with open(trialdict_fn, 'r') as f0:
        trialdict = json.load(f0)

    # correctly classified trials: "category, trial id", sorted by trial id
    correctclasslist = sorted(
        ((trialdict[trial], trial) for trial in correctlyclassified),
        key=lambda row: row[1])
    _savetriallines(corrects_fn, correctclasslist)

    # close trials: "trial id, category, difference", sorted by difference
    # (closetrials and difflist are filled in lockstep, so zip pairs them up)
    closetriallist = sorted(
        ((trial, trialdict[trial], diff)
         for trial, diff in zip(closetrials, difflist)),
        key=lambda row: row[2])
    _savetriallines(
        closetrials_fn,
        ((trial, cat, str(diff)) for trial, cat, diff in closetriallist))

    # false positives: "category, trial id", sorted by offense
    falsepositivelist = sorted(
        ((trialdict[trial], trial) for trial in falsepositives),
        key=lambda row: row[0])
    _savetriallines(falseps_fn, falsepositivelist)

    
# Calculate accuracy values
def _ratio(numerator, denominator):
    """Return numerator / denominator as a float, or 0.0 when denominator is not positive.

    Guards every average below, so that an empty sample directory (no runs)
    or empty test sets give zeros instead of a ZeroDivisionError.
    """
    if denominator > 0:
        return numerator / float(denominator)
    return 0.0


print('Calculating accuracy of classification...')
# precision: share of target-category guesses that were right
precision = _ratio(sum(hitlist), sum(guesslist))
# recall: share of target-category trials in the test samples that were found
recall = _ratio(sum(hitlist), sum(catinsamplelist))

# accuracy: share of all tested trials whose guessed class was correct
accuracy = _ratio(sum(correctguesslist), sum(totalslist))
# after the loop, run holds the number of cross-validation runs performed
avg_catinsample = _ratio(sum(catinsamplelist), run)
avg_total = _ratio(sum(totalslist), run)
end = time.time()
elapsedtime = end - start


# Finally, report the results: classification mode first, then the scores.
print('Two-way classification, target category {0}.'.format(cattocheck)
      if cattocheck else 'Multi-way classification')
print('And the results are:')
print('Accuracy {:.2%}'.format(accuracy))
# precision and recall only make sense when a target category was evaluated
# and it actually occurred in at least one test sample
if cattocheck and sum(catinsamplelist) > 0:
    print('Precision: {:.2%}'.format(precision))
    print('Recall: {:.2%}'.format(recall))
    print('Average number of target category trials in test sample per run: {0}'.format(avg_catinsample))
elif cattocheck:
    print('No target category trials in test sample in any run. Maybe you should check that the category you are checking for exists?')
print('Average number of trials in test sample per run: {0}'.format(avg_total))
print('Obtained in {0:.2f} seconds'.format(elapsedtime))