import threading
import argparse
import os
import multiprocessing as mp
import numpy as np
import merge_parallel as tbmerge
from ROOT import TFile, TTree

def main():
    """Split the input ``tracks`` tree into per-CPU slices and merge them in parallel.

    Reads the entry count of the ``tracks`` tree in the first ``-f`` file,
    partitions it with ``divide_files`` and runs ``mergethread`` on every
    slice in a ``multiprocessing`` pool. Any exception raised inside a worker
    is re-raised here instead of being silently dropped.
    """
    parser = argparse.ArgumentParser(description='Script for merging track collections for TBMon2 analysis')
    # Without -f the script cannot do anything; let argparse report it.
    parser.add_argument('-f', dest='files', action='append', required=True)
    parser.add_argument('-r', dest='run', action='append')
    args = parser.parse_args()

    # -r uses action='append', so args.run is a list (or None). Passing the
    # list itself would put "['...']" into every output file name.
    run = '_'.join(args.run) if args.run else None

    tfile = TFile(args.files[0])
    try:
        ttree = tfile.Get('tracks')
        if not ttree:
            parser.error("no 'tracks' tree found in %s" % args.files[0])
        entries = ttree.GetEntries()
    finally:
        # Only the entry count is needed here; release the file before forking workers.
        tfile.Close()

    n_cpus = mp.cpu_count()
    entrieslist, boundarieslist = divide_files(entries, n_cpus)

    pool = mp.Pool(n_cpus)
    try:
        results = [
            pool.apply_async(mergethread, (args.files, start, size, run, index))
            for index, (start, size) in enumerate(zip(boundarieslist, entrieslist))
        ]
        pool.close()
        # get() re-raises a worker's exception; apply_async alone would swallow it.
        for result in results:
            result.get()
    except BaseException:
        pool.terminate()
        raise
    finally:
        pool.join()

def mergethread(files, start, entries, run, index):
    """Merge one contiguous slice of the input into its own output file.

    Intended to run inside a pool worker. ``start`` and ``entries`` are
    coerced to ``int`` and handed, together with ``files`` and an output name
    of the form ``"<run>_output<index>"``, to ``tbmon_merge.run``.
    """
    merger = tbmerge.tbmon_merge()
    first_entry, n_entries = int(start), int(entries)
    target_name = "_output".join((str(run), str(index)))
    merger.run(files, target_name, first_entry, n_entries)

def divide_files(entries, n_cpus):
    """Split ``entries`` consecutive entries into ``n_cpus`` contiguous chunks.

    Every chunk except the last gets ``entries // n_cpus`` entries. The last
    chunk also absorbs the remainder, so the sizes always sum to ``entries``.

    Args:
        entries: total number of entries to distribute.
        n_cpus: number of chunks to produce; non-positive values yield no chunks.

    Returns:
        ``(entrieslist, boundarieslist)``: the size of each chunk, and the
        index of its first entry (running sum of the preceding sizes).
    """
    if n_cpus <= 0:
        return [], []
    # Integer floor division: a float quotient loses precision once
    # ``entries`` exceeds 2**53 and would skew the chunk sizes.
    per_chunk = int(entries // n_cpus)
    entrieslist = [per_chunk] * (n_cpus - 1)
    entrieslist.append(int(entries - per_chunk * (n_cpus - 1)))

    boundarieslist = []
    running_total = 0
    for size in entrieslist:
        boundarieslist.append(running_total)
        running_total += size
    return entrieslist, boundarieslist

if __name__ == "__main__":
    # Guard keeps the pool from being re-created when the module is imported,
    # e.g. by multiprocessing workers on spawn-based start methods.
    main()
