import root_to_hist as rth
import matplotlib.pyplot as plt

# Input ROOT files for the ttbar truth/reco comparison plots.
# Nominal truth-level sample.
TRUTH_FILE = '/user/msullivan/Wh_Upgrade_Studies/Studies/ttbar_truth_FULL.root'
# Alternative truth-level sample; root_plot labels it "Truth-level (no pileup jets)".
TRUTH_VAR_FILE = '/user/msullivan/Wh_Upgrade_Studies/Studies/ttbar_truth.root'
# Reco-level sample.
RECO_FILE = '/user/msullivan/Wh_Upgrade_Studies/Studies/ttbar_reco_FULL.root'

# x-axis titles (ROOT LaTeX-style markup), keyed by truth histogram name.
axis_labels = {
    "weighted_TRUTH_MET": "E_{T}^{miss} [GeV]",
    "weighted_TRUTH_MT": "m_{T} [GeV]",
}

def main():
    """Produce the density-normalised ROOT truth-vs-reco plots for MET and MT."""
    comparisons = (
        ('weighted_TRUTH_MET', 'weighted_RECO_MET'),
        ('weighted_TRUTH_MT', 'weighted_RECO_MT'),
    )
    for truth_name, reco_name in comparisons:
        root_plot(truth_name, reco_name, density=True)

def root_plot(_truth_hist_name, _reco_hist_name, density=False, x_range=(0, 500)):
    """Overlay truth, truth (no pileup jets) and reco histograms with PyROOT.

    ``_truth_hist_name`` is read from both TRUTH_FILE and TRUTH_VAR_FILE, and
    ``_reco_hist_name`` from RECO_FILE. The three histograms are drawn on a
    single log-y canvas, which is saved as ``<_truth_hist_name>.pdf``.

    Args:
        _truth_hist_name: truth histogram name. It also selects the x-axis
            title from ``axis_labels`` (falling back to the name itself) and
            is used as the output file stem.
        _reco_hist_name: reco histogram name in RECO_FILE.
        density: if True, scale every histogram to unit integral and fix the
            y range to [1e-4, 1]. If False, the raw yields are drawn and ROOT
            chooses the y range.
        x_range: ``(low, high)`` x-axis range, in axis units.
    """
    from ROOT import TCanvas, gStyle, TLegend

    # Keep the root_to_hist wrappers referenced until the canvas is saved.
    # NOTE(review): they presumably own the open TFile holding the histograms,
    # so releasing them early could invalidate the histograms -- confirm.
    ttbar_truth = rth.root_to_hist(TRUTH_FILE, _truth_hist_name)
    ttbar_truth_var = rth.root_to_hist(TRUTH_VAR_FILE, _truth_hist_name)
    ttbar_reco = rth.root_to_hist(RECO_FILE, _reco_hist_name)

    truth_hist = ttbar_truth.get_hist()
    truth_var_hist = ttbar_truth_var.get_hist()
    reco_hist = ttbar_reco.get_hist()

    # Normalise before drawing, so the code order matches what gets painted.
    if density:
        for hist in (truth_hist, truth_var_hist, reco_hist):
            integral = hist.Integral()
            # An empty histogram has zero integral. Leave it unscaled instead
            # of raising ZeroDivisionError.
            if integral != 0:
                hist.Scale(1.0 / integral)

    # Create new canvas
    c = TCanvas('c', 'c', 1600, 1200)
    c.SetLogy()

    # Get rid of stats box
    gStyle.SetOptStat(0)

    # Cosmetics
    truth_hist.SetTitle('')
    truth_hist.GetXaxis().SetTitle(axis_labels.get(_truth_hist_name, _truth_hist_name))
    truth_hist.SetLineColor(807)
    truth_hist.SetLineWidth(2)

    truth_var_hist.SetLineColor(877)
    truth_var_hist.SetLineWidth(2)

    reco_hist.SetFillColorAlpha(597, 0.3)

    # Ranges: truth_hist is drawn first, so its axes define the frame.
    truth_hist.GetXaxis().SetRangeUser(*x_range)
    if density:
        # With non-negative weights, a unit-integral histogram has bin
        # contents <= 1. The 1e-4 floor keeps the log axis finite. Raw yields
        # (density=False) can exceed 1, so ROOT autoscales them instead.
        truth_hist.GetYaxis().SetRangeUser(1E-4, 1)

    # Draw histograms
    truth_hist.Draw('HIST E0 X0')
    truth_var_hist.Draw('HIST E0 X0 SAME')
    reco_hist.Draw('HIST E0 X0 SAME')

    # Legend: the reco histogram is filled, so use 'lf' to show its fill as well.
    leg = TLegend(0.6, 0.6, 0.8, 0.8)
    leg.AddEntry(truth_hist, "Truth-level", 'l')
    leg.AddEntry(truth_var_hist, "Truth-level (no pileup jets)", 'l')
    leg.AddEntry(reco_hist, "Reco-level", 'lf')
    leg.Draw()

    c.SaveAs('{}.pdf'.format(_truth_hist_name))

def mpl_plot():
    """Save matplotlib truth-vs-reco overlays of MET (MET.png) and MBB (MBB.png).

    Each plot shows the truth histogram from TRUTH_FILE and the reco histogram
    from RECO_FILE as density-normalised step histograms on a log y axis.
    """
    _mpl_compare('weighted_TRUTH_MET', 'weighted_RECO_MET', 'MET [GeV]', 'MET.png')
    _mpl_compare('weighted_TRUTH_MBB', 'weighted_RECO_MBB', 'MBB [GeV]', 'MBB.png')


def _mpl_compare(truth_name, reco_name, xlabel, out_file):
    """Overlay one truth/reco histogram pair and save the figure to *out_file*."""
    truth = rth.root_to_hist(TRUTH_FILE, truth_name)
    reco = rth.root_to_hist(RECO_FILE, reco_name)

    # Use a fresh figure for every plot. plt.figure(0)/(1) would reuse, and
    # draw on top of, an already-existing figure on a repeated call.
    fig, ax = plt.subplots()
    for hist, label in ((truth, 'TRUTH'), (reco, 'RECO')):
        # `.yields` is passed as hist()'s `bins` parameter, exactly as the
        # previous positional call did.
        # NOTE(review): confirm that root_to_hist's `.yields` holds bin edges
        # and `.bins` holds the per-bin x positions.
        ax.hist(hist.bins, bins=hist.yields, weights=hist.weights, log=True,
                histtype='step', label=label, density=True)
    ax.legend(loc='upper right')
    ax.set_xlabel(xlabel)
    # density=True normalises to unit area, so the y axis is a density per
    # x-axis unit (GeV for both callers), not an event count.
    ax.set_ylabel('Density [1/GeV]')
    # `quality` only ever applied to JPEG output. It was deprecated in
    # Matplotlib 3.3 and is rejected by later releases, so it is not passed.
    fig.savefig(out_file, dpi=500)
    plt.close(fig)

if __name__ == '__main__':
    main()