# -*- coding: utf-8 -*-
##################################################################
# Various small routines to help with testing machine learners  ##
##################################################################
import re, os, sys
        
    
def tokenize_text(text,stopwordfilename):
    """Split *text* into lowercased word tokens with stopwords removed.

    Args:
        text: the string to tokenize; it is split on whitespace.
        stopwordfilename: path of a file with one stopword per line. Entries
            are stripped but compared as-is against the lowercased tokens,
            so they should already be lowercase.

    Returns:
        List of lowercased tokens in their original order, excluding
        stopwords and tokens of length 1. No stemming is performed.
    """
    with open(stopwordfilename) as f:
        # A set gives O(1) membership tests instead of a list scan per token.
        stopwords = {line.strip() for line in f}

    # Check the cheap length condition first; the length is measured on the
    # original token, exactly as before.
    return [w.lower() for w in text.split()
            if len(w) > 1 and w.lower() not in stopwords]


def process_data(categoriesdir,stopwordfilename,cattocheck=None,pattern=None):
    """Load and tokenize one ``.txt`` file per category from a directory.

    Every regular file ending in ``.txt`` directly inside *categoriesdir*
    is one category; each line of such a file is one document.

    Args:
        categoriesdir: directory holding the category files. A trailing
            path separator is optional.
        stopwordfilename: path of the stopword file, passed on to
            tokenize_text for every document.
        cattocheck: if given (and truthy), the category equal to it keeps
            its name and every other category is relabelled 'other'.
        pattern: optional regex; its first match in the filename becomes
            the category name. Without it, the filename minus '.txt' is used.

    Returns:
        A list with one entry per file. Each entry is a list whose first
        element is the category name, followed by the tokenized documents:
        [['catname', ['words', 'of', 'first', 'doc'],
                     ['words', 'of', 'second', 'doc']], ...]
        Files are processed in sorted filename order, so the result is
        deterministic.

    Raises:
        ValueError: if *pattern* is given but does not match a filename.
    """
    data = []
    # NOTE(review): kept from the original; the reason for flushing stdout
    # here is not visible in this file.
    sys.stdout.flush()
    for item in sorted(os.listdir(categoriesdir)):
        path = os.path.join(categoriesdir, item)
        if not (item.endswith('.txt') and os.path.isfile(path)):
            continue
        # Use a context manager so the file is closed promptly.
        with open(path) as fh:
            lines = [line.strip() for line in fh]
        if pattern:
            match = re.search(pattern, item)
            if match is None:
                raise ValueError(
                    "pattern %r does not match filename %r" % (pattern, item))
            catname = match.group(0)
        else:
            catname = item[:-4]  # drop the '.txt' extension
        # One-vs-rest: only the requested category keeps its own label.
        if cattocheck and cattocheck != catname:
            category = 'other'
        else:
            category = catname
        docs = [category]
        docs.extend(tokenize_text(line, stopwordfilename) for line in lines)
        data.append(docs)
    return data


def create_sets(sampledirname,run=0):
    """Split sample ids into a training set and a test set for one fold.

    *sampledirname* holds one file per sample (e.g. the folds of a 10-fold
    cross-validation); each line of a file is one id (filename). Files
    whose name contains *run* as a standalone number form the test set;
    all other files form the training set.

    Args:
        sampledirname: directory with the sample files. A trailing path
            separator is optional; subdirectories are ignored.
        run: identifier of the fold to hold out for testing (default 0).

    Returns:
        (trainset, testset): two lists of ids, concatenated in sorted
        filename order.
    """
    # Match the run number only when it is not part of a longer number:
    # run=1 selects 'fold1.txt' but not 'fold10.txt' or '2015_fold3.txt',
    # both of which a plain substring test would wrongly select.
    run_re = re.compile(r'(?<!\d)' + re.escape(str(run)) + r'(?!\d)')
    trainset = []
    testset = []
    for fn in sorted(os.listdir(sampledirname)):
        currentfile = os.path.join(sampledirname, fn)
        if not os.path.isfile(currentfile):
            continue
        with open(currentfile) as f1:
            thesetrials = [line.strip() for line in f1]
        if run_re.search(fn):
            testset += thesetrials
        else:
            trainset += thesetrials
    return trainset, testset


def splittraintest(trainsetids,testsetids,data):
    """Partition categorised documents into training and test data.

    Args:
        trainsetids: ids (e.g. from create_sets) of the documents that go
            into the training set. A document's id is its first token.
        testsetids: accepted for symmetry with create_sets but not used:
            every document whose id is not in *trainsetids* goes to test.
        data: list in the format produced by process_data,
            [[catname, doc1_tokens, doc2_tokens, ...], ...].

    Returns:
        (traindata, testdata), both in the same format as *data*. A category
        appears in an output list only if it has at least one document there.
    """
    # Build the lookup set once; list membership would cost O(len(ids))
    # for every document.
    train_ids = set(trainsetids)
    traindata = []
    testdata = []
    for cat in data:
        # cat[0] is the category name; the remaining items are documents.
        catname, trials = cat[0], cat[1:]
        traintrials = []
        testtrials = []
        for trial in trials:
            # An empty document (e.g. a blank or all-stopword line) has no id.
            # Treat it as not-in-training instead of raising IndexError.
            if trial and trial[0] in train_ids:
                traintrials.append(trial)
            else:
                testtrials.append(trial)
        if traintrials:
            traindata.append([catname] + traintrials)
        if testtrials:
            testdata.append([catname] + testtrials)
    return traindata, testdata
    