import histogram
import matplotlib.pyplot as plt
import numpy as np

from ROOT import TFile, TH1D

# Input ROOT files read by plot(); each one supplies one curve per figure.
# TRUTH     -> curve labelled 'Truth-level'
# TRUTH_VAR -> curve labelled 'Truth-level with smearing'
# RECO      -> curve labelled 'Reco-level'
TRUTH = '/user/msullivan/Wh_Upgrade_Studies/Studies/ttbar_truth_FULL.root'
TRUTH_VAR = '/user/msullivan/Wh_Upgrade_Studies/Studies/ttbar_truth.root'
RECO = '/user/msullivan/Wh_Upgrade_Studies/Studies/ttbar_reco_FULL.root'

def main():
    """Produce one truth-vs-reco comparison plot per kinematic variable."""
    # (truth histogram key, reco histogram key, x-axis label)
    comparisons = (
        ('weighted_TRUTH_MET', 'weighted_RECO_MET', r'$E_{T}^{miss}$ [GeV]'),
        ('weighted_TRUTH_MT', 'weighted_RECO_MT', r'$m_{T}$ [GeV]'),
        ('weighted_TRUTH_MCT', 'weighted_RECO_MCT', r'$m_{CT}$ [GeV]'),
        ('weighted_TRUTH_MBB', 'weighted_RECO_MBB', r'$m_{bb}$ [GeV]'),
    )

    for truth_key, reco_key, axis_label in comparisons:
        plot(truth_key, reco_key, axis_label)

def root_to_hist(_tfile, _hist):
    """Copy a 1D ROOT histogram from a file into a ``histogram.histogram``.

    Parameters
    ----------
    _tfile : str
        Path of the ROOT file to open.
    _hist : str
        Key of the histogram inside that file.

    Returns
    -------
    histogram.histogram
        A new histogram built with the same number of bins and the same
        x-axis range as the ROOT histogram, filled once per in-range ROOT
        bin with ``fill(bin centre, bin content)``.

    Raises
    ------
    IOError
        If the ROOT file could not be opened.
    KeyError
        If the file holds no object under the key ``_hist``.
    """
    tfile = TFile(_tfile)
    if tfile.IsZombie():
        raise IOError('Could not open ROOT file: {}'.format(_tfile))

    try:
        hist = tfile.Get(_hist)
        if not hist:
            raise KeyError('No object named {!r} in {}'.format(_hist, _tfile))

        axis = hist.GetXaxis()
        nbins = axis.GetNbins()
        xlow = axis.GetBinLowEdge(1)
        # Low edge of the overflow bin is the upper edge of the last real bin.
        xhigh = axis.GetBinLowEdge(nbins + 1)

        mhist = histogram.histogram(nbins, xlow, xhigh)

        # ROOT numbers in-range bins 1..nbins; bin 0 is underflow and bin
        # nbins+1 is overflow, so neither is copied.
        for ibin in range(1, nbins + 1):
            mhist.fill(hist.GetBinCenter(ibin), hist.GetBinContent(ibin))
    finally:
        # The ROOT histogram is owned by the file, so close only after every
        # value needed has been read out of it.
        tfile.Close()

    return mhist

def plot(_truth_hist, _reco_hist, xlabel):
    """Overlay normalised truth, smeared-truth and reco histograms and save a PNG.

    Parameters
    ----------
    _truth_hist : str
        Histogram key read from both the TRUTH and TRUTH_VAR files. It also
        selects the axis limits and names the output file
        (``<_truth_hist>.png``, written to the current working directory).
    _reco_hist : str
        Histogram key read from the RECO file.
    xlabel : str
        Label for the x axis.
    """
    truth_hist = root_to_hist(TRUTH, _truth_hist)
    curves = (
        (truth_hist, 'Truth-level'),
        (root_to_hist(TRUTH_VAR, _truth_hist), 'Truth-level with smearing'),
        (root_to_hist(RECO, _reco_hist), 'Reco-level'),
    )

    for hist, _ in curves:
        hist.normalise()

    # Reuse figure 0 for every call and wipe whatever the previous call drew.
    plt.figure(0)
    plt.gcf().clear()

    # Draw each pre-binned histogram by passing one entry per bin (its centre)
    # weighted by that bin's content.
    for hist, label in curves:
        plt.hist(hist.bin_centres, hist.bin_edges,
                 weights=hist.bin_wgt_contents, histtype='step', label=label)

    # Limits: defaults first, then per-variable overrides.
    plt.xlim(0, 500)
    plt.ylim(1E-4, 0.5)
    if 'MCT' in _truth_hist:
        plt.xlim(0, 1000)
        plt.ylim(1E-5, 0.5)
    if 'MBB' in _truth_hist:
        plt.ylim(5E-3, 0.2)

    # Cosmetics
    plt.yscale('log')
    plt.xlabel(xlabel)
    # Width of the first bin, truncated to an integer for the axis label.
    bin_width = int(truth_hist.bin_edges[1] - truth_hist.bin_edges[0])
    plt.ylabel('Events / {} GeV'.format(bin_width))
    plt.legend(loc='upper right')

    # Save. No `quality` keyword: it only ever affected JPEG output and newer
    # matplotlib releases reject it.
    plt.savefig('{}.png'.format(_truth_hist), dpi=500)


# Generate the plots only when run as a script, not when imported.
if __name__ == '__main__':
    main()