#!/usr/bin/env python
# -*- coding: utf-8 -*-                                                                                                                                                               
import os,sys
from array import array
import inspect                                                                                                                                     
import StringIO
from collections import OrderedDict 
from ROOT import gROOT 
from ROOT import RooWorkspace
from ROOT import RooAbsData
from ROOT import RooDataSet
from ROOT import RooFit
from math import *
from ROOT import kBlack,kWhite,kGray,kRed,kPink,kMagenta,kViolet,kBlue,kAzure,kCyan,kTeal,kGreen,kSpring,kYellow,kOrange,kDashed,kSolid,kDotted
from ROOT import TTree, TMath, TChain, TF1, TH1I, TH1F, TH1D, TH2F, TH2D, TGraphErrors, TGraph, THStack, TLegend, TGraph2D, TLatex
from ROOT import TFile, TCanvas, TPad, TLine, TLorentzVector
from ROOT import gPad,gStyle                                                                   

from AtlasStyle import AtlasStyle


def main():
    """Parse the command-line options and draw the SR1L post-fit E_T^miss plot.

    The two options give the path prefixes of the per-bin background and
    signal ROOT files; everything else (palette, signal points, binning,
    legend layout) is fixed below and handed to ``plotVars``.
    """
    from optparse import OptionParser

    parser = OptionParser(usage="usage: %prog arguments", version="%prog")
    parser.add_option(
        '-b', '--inputdirBkg',
        dest='inputdirBkg',
        default='/lhome/ific/c/casflo/work/monopy_new/monopy/plots/results/monotop/',
        help='input file for bkg  (default: %default)',
    )
    parser.add_option(
        '-s', '--inputdirSig',
        dest='inputdirSig',
        default='/lhome/ific/c/casflo/work/monopy_new/monopy/plots/results/Signal_tj/',
        help='input file for sig  (default: %default)',
    )
    options, _unused_args = parser.parse_args()

    # Earlier configurations, kept for reference:
    #BkgColors = dict(val1=('Top', kOrange - 2, "t#bar{t}"), val2 = ('Wjets', kCyan +2, "W+jets"), val3 = ('smallbkg', 2, "Others"))
    #SigColors = { 'a250_DM10_H1750_tb0p3': [1, "m_{a}=250 GeV m_{H}=1750 GeV tan#beta=0.3"], 'a250_DM10_H1000_tb0p3' : [413, "m_{a}=250 GeV m_{H}=1000 GeV tan#beta=0.3"]}

    # key -> (object name in the ROOT file, fill colour, legend label)
    BkgColors = {
        'val1': ('Top1L', kOrange - 2, "t#bar{t}"),
        'val2': ('Wjets', kCyan + 2, "W+jets"),
        'val3': ('SingleTop', kOrange + 2, "Single top quark"),
        'val4': ("Diboson1L", kMagenta - 4, "Diboson"),
        'val5': ("ttV1L", kAzure - 3, "t#bar{t}V"),
        'val6': ("Zjets", kYellow, "Z+jets"),
        'val7': ("tWZ", kYellow, "tWZ"),
        'val8': ("ttH", kYellow, "t#bar{t}h"),
    }

    print('dotW1L')

    # object name in the signal file -> [line colour, legend label]
    SigColors = {
        'a250_DM10_H600_tb2_st0p7_1L0L': [413, "m_{a}=250 GeV m_{H^{#pm}}=600 GeV tan#beta=2"],
        'a300_DM10_H1200_tb1_st0p7_1L0L': [1, "m_{a}=300 GeV m_{H^{#pm}}=1200 GeV tan#beta=1"],
    }

    # Alternative BDT-output plot, kept for reference:
    #plotVars(options.inputdirBkg,options.inputdirSig,"SR",BkgColors, SigColors, "afterFit","BDT", "tj-1L SR","Events","BDT output",[0.6, 0.75, 0.85, 0.9, 1.0], "log", {'leg' : [0.32, 0.55, 0.79, 0.87, 0.15]})
    met_edges = [250, 300, 400, 500, 600, 1000]
    layout = {'leg': [0.36, 0.46, 0.85, 0.85, 0.14]}
    plotVars(
        options.inputdirBkg,
        options.inputdirSig,
        "SR1L",
        BkgColors,
        SigColors,
        "afterFit",
        "met1000",
        "SR_{tW_{1L}}",
        "Events / bin width",
        "E_{T}^{miss} [GeV]",
        met_edges,
        "log",
        layout,
    )

#  {'leg' : [0.43, 0.55, 0.87,0.87, 0.15]})
    #plotVars(options.inputdirBkg,options.inputdirSig,"SR",BkgColors, "", "afterFit","BDT", "Events","BDT output",[0.6, 0.75,0.85, 0.9, 1.0], "lin")  
 


# function to get unique values 
def unique(list1):
    """Return the distinct items of ``list1`` in order of first appearance.

    Membership is tested with ``==`` against the items kept so far, so the
    items do not need to be hashable.
    """
    kept = []
    for item in list1:
        if item in kept:
            # already recorded on an earlier occurrence
            continue
        kept.append(item)
    return kept
        


# Function to get errors
def getError(h_object) :
    """Extract the repeated y-levels from the first 15 points of a graph.

    ``h_object`` must provide ``GetPoint(i, x, y)`` filling one-element
    ``array('d')`` buffers.  Points whose y-value is zero are ignored.  Of the
    remaining y-values, those occurring exactly five times are returned as an
    ascending list.
    NOTE(review): presumably these are the lower/upper edges of an error band
    drawn as a closed outline -- confirm against the producer of the graphs.
    """
    px = array('d', [0])
    py = array('d', [0])

    nonzero_y = []
    for point in range(15):
        h_object.GetPoint(point, px, py)
        if not py[0]:
            continue
        nonzero_y.append(py[0])

    # keep only the levels that show up exactly five times
    return sorted(level for level in unique(nonzero_y) if nonzero_y.count(level) == 5)
    



def _style_legend(legend):
    """Apply the shared look (text size/font, transparent fill, no border) to a TLegend."""
    legend.SetTextSize(0.040)
    legend.SetFillStyle(0)
    legend.SetMargin(0.22)
    legend.SetFillColor(0)
    legend.SetBorderSize(0)
    legend.SetTextFont(42)


def plotVars(folderBkg, folderSig, region, bkg,signal,fit, name_vars, caption,ylabel, xlabel, binning, scale, style_plot) :
    """Draw a stacked distribution with a Data/SM ratio panel and save it as PDF.

    Every bin of the final plot is read from its own ROOT file (one file per
    bin, each contributing the first point / first bin of its objects).

    Args:
        folderBkg: path prefix of the background files; bin ``k`` (0-based) is
            read from ``folderBkg + region + "Bin" + k + "_cuts_" + fit + ".root"``.
        folderSig: path prefix of the signal files; bin ``k`` is read from
            ``folderSig + region + "Bin" + k + "_cuts_beforeFit.root"``.
        region: region tag used in the file names.
        bkg: dict ``key -> (object name in file, fill colour, legend label)``.
            Stack order is the reverse-sorted keys, legend order the sorted
            keys.  Key ``"val6"`` is labelled "Others" and keys ``"val7"`` /
            ``"val8"`` get no legend entry of their own.
        signal: dict ``object name in file -> [line colour, legend label]``;
            any falsy value ("" or None) disables the signal overlay.
        fit: fit tag used in the background file names and the output name.
        name_vars: prefix of the output file ``name_vars + "_" + fit + ".pdf"``.
        caption: text printed below the luminosity label.
        ylabel: y-axis title of the upper pad.
        xlabel: x-axis title of the ratio pad.
        binning: sequence of bin edges; ``len(binning) - 1`` files are read.
        scale: ``"log"`` for a logarithmic y-axis, anything else for linear.
        style_plot: dict with ``'leg': [x1, y1, x2, y2, text_x]`` giving the
            legend box and the x position of the text labels (NDC); falsy to
            use the built-in default layout.
    """
    # A falsy ``signal`` means "no overlay"; normalise it so the loops below
    # can iterate without special cases (the original crashed on None).
    signal = signal or {}

    style = AtlasStyle()
    c, pad1, pad2 = createCanvasPads()

    ##############################################
    # Legends
    ##############################################
    # leg1 (separate signal legend) only exists with a custom layout; without
    # one the signal entries go into the main legend.
    leg1 = None
    if style_plot:
        box = style_plot['leg']
        leg = TLegend(box[0]+0.12, box[1], box[2], box[3])
        if signal:
            leg1 = TLegend(box[0], box[1]-0.12, box[2], box[3]-0.38)
            _style_legend(leg1)
    else:
        leg = TLegend(0.65, 0.60, 0.83, 0.87)
    _style_legend(leg)

    pad1.cd()

    # Bin edges as a C double array, as required by the TH1F constructor
    bini = array('d',binning)
    # One input file per bin
    nbins = len(bini)-1

    ##############################################
    ## Getting the bin content
    ##############################################
    rootFile = {}
    rootFileSig = {}
    data = {}
    SM = {}
    # Error bands; rel_err is applied in the ratio panel
    total_err = {}
    rel_err = {}
    ratio = {}
    bkgHisto = {}
    signalHisto = {}
    for i in range(1, nbins+1) :
        name = folderBkg+str(region)+"Bin"+str(i-1)+"_cuts_"+str(fit)+".root"
        rootFile[i] = TFile(name)
        rel_err[i] = getError(rootFile[i].Get("h_rel_error_band"))
        ratio[i] = rootFile[i].Get("h_ratio")
        data[i] = rootFile[i].Get("h_obsData")
        SM[i] = rootFile[i].Get("SM_total")
        total_err[i] = getError(rootFile[i].Get("h_total_error_band"))
        for j in bkg:
            bkgHisto[str(bkg[j][0])+str(i)] = rootFile[i].Get(str(bkg[j][0]))

        if signal:
            inputFileSignal = folderSig+str(region)+"Bin"+str(i-1)+'_cuts_beforeFit.root'
            rootFileSig[i] = TFile(inputFileSignal)
            for j in signal:
                signalHisto[str(j)+str(i)] = rootFileSig[i].Get(str(j))

    ##############################################
    # Here are defined the Histograms
    ##############################################
    data_all = TH1F("","",nbins,bini)
    SM_all = TH1F("","",nbins,bini)
    ratio_all = TH1F("","",nbins,bini)

    bkg_all = {}
    for j in bkg:
        process = bkg[j][0]
        bkg_all[process] = TH1F("","",nbins,bini)
        bkg_all[process].SetFillColor(bkg[j][1])
        bkg_all[process].SetLineWidth(0)

    sig_all = {}
    for sg in signal:
        sig_all[sg] = TH1F("","",nbins,bini)
        sig_all[sg].SetLineColor(signal[sg][0])
        sig_all[sg].SetFillColor(0)
        sig_all[sg].SetLineWidth(3)
        sig_all[sg].SetMarkerStyle(0)
        sig_all[sg].SetLineStyle(2)

    # Stack of the background processes
    hs = THStack()

    SM_all.SetMarkerStyle(1)
    SM_all.SetFillColor(kBlack)
    SM_all.SetFillStyle(3004)
    SM_all.SetLineWidth(3)
    SM_all.SetLineColor(kRed+2)
    data_all.SetLineWidth(1)

    # One-element buffers filled by GetPoint
    x = array('d',[0])
    y = array('d',[0])
    y0 = array('d',[0])

    # Fill the bin content
    for i in range(1, nbins+1):
        # Data
        data[i].GetPoint(0,x,y)
        data_all.SetBinContent(i, float(y[0]))

        # ratio
        ratio[i].GetPoint(0,x,y0)
        ratio_all.SetBinContent(i, float(y0[0]))
        ratio_all.SetBinError(i, ratio[i].GetErrorY(0))

        # SM: the error is the distance between the yield and the lowest
        # level of the total error band
        sm_yield = SM[i].GetBinContent(1)
        SM_all.SetBinContent(i, sm_yield)
        SM_all.SetBinError(i, abs(total_err[i][0]- sm_yield))

        for j in signal:
            sig_all[j].SetBinContent(i, signalHisto[str(j)+str(i)].GetBinContent(1))

        for j in bkg:
            process = bkg[j][0]
            try :
                content = bkgHisto[str(process)+str(i)].GetBinContent(1)
            except (AttributeError, ReferenceError, TypeError):
                # The process is absent from this bin's file (Get returned a
                # null object): leave the bin empty, as before, but no longer
                # hide unrelated errors behind a bare except.
                continue
            bkg_all[process].SetBinContent(i, content)

    # Normalise everything that is drawn in the upper pad to the bin width
    binATLAS(data_all)
    binATLAS(SM_all)
    for j in bkg:
        binATLAS(bkg_all[bkg[j][0]])

    leg.AddEntry(data_all, "Data", "PX0E")
    leg.AddEntry(SM_all, "SM Total", "fl")

    # Stack add your SM process here
    for j in reversed(sorted(bkg)):
        hs.Add(bkg_all[bkg[j][0]])

    maxY = hs.GetMaximum()
    for sg in signal:
        # The assembled signal histogram is what gets drawn, so this is the
        # one to normalise.  The original normalised the per-bin inputs after
        # their content had been copied, which left the overlay unscaled
        # relative to data and backgrounds.
        binATLAS(sig_all[sg])
        if sig_all[sg].GetMaximum() > maxY:
            maxY = sig_all[sg].GetMaximum()

    ylabel = str(ylabel)

    # First Draw creates the stack's axes so they can be configured
    hs.Draw("HIST")
    hs.GetYaxis().SetTitle(ylabel)
    hs.GetYaxis().SetTitleOffset(0.9)
    hs.SetMaximum(2 * maxY)
    hs.GetXaxis().SetLabelSize(0.)
    hs.GetYaxis().SetTitleSize(0.055)
    hs.GetYaxis().SetLabelSize(0.05)

    if scale == "log":
        # extra head-room for legend and labels on a logarithmic axis
        hs.SetMaximum(1000. * maxY)
        hs.SetMinimum(1)
        gPad.SetLogy()

    hs.Draw("HIST F")

    for sg in signal :
        sig_all[sg].GetYaxis().SetRangeUser(1, maxY)
        sig_all[sg].Draw("HIST SAME")

    # Outline of the SM total (no fill) plus its hatched uncertainty band
    SM_clone = SM_all.Clone("SM_clone")
    SM_clone.SetFillStyle(0)
    SM_clone.Draw("HIST SAME")
    SM_all.Draw("E2 SAME")
    data_all.Draw("P E1 X0 SAME")
    SM_all.Draw("AXIS SAME")

    for e in sorted(bkg):
        # NOTE(review): keys val6-val8 are shown as one "Others" entry; this
        # mirrors the palette used by the caller, where they share a colour.
        if e == "val6":
            leg.AddEntry(bkg_all[bkg[e][0]], "Others", "f")
            continue
        if e == "val7" or e == "val8":
            continue
        leg.AddEntry(bkg_all[bkg[e][0]], bkg[e][2], "f")

    if signal:
        # Entries are taken from the signal dict itself (reverse-sorted keys)
        # instead of two hard-coded sample names, so any signal set works.
        signal_legend = leg1 if leg1 is not None else leg
        for sg in sorted(signal, reverse=True):
            signal_legend.AddEntry(sig_all[sg], signal[sg][1], 'l')
        if leg1 is not None:
            leg1.Draw()

    # Text labels (NDC coordinates)
    if style_plot:
        text_x = style_plot['leg'][4]
        text_top = style_plot['leg'][3]
        atlas_y = text_top-0.03
        lumi_y = text_top-0.09
        caption_y = text_top-0.15
    else:
        text_x = 0.2
        # was 85, which is far outside the 0..1 NDC range
        atlas_y = 0.85
        lumi_y = 0.80
        caption_y = 0.75
    custom_text(text_x, atlas_y, "#bf{#it{ATLAS}} %s" % "Preliminary", 0.065, 1)
    style.myText(text_x,lumi_y,color=1,size=0.055,text="#sqrt{s} = 13 TeV, 139 fb^{#minus1}")
    style.myText(text_x,caption_y,color=1,size=0.055,text="{} {}".format(str(caption),"#bf{Post-fit}"))

    leg.Draw()

    pad2.cd()

    ##############################################
    # Create the ratio
    ##############################################
    err, h = createRatio(ratio_all, rel_err, 1, xlabel)

    err.SetFillStyle(3004)
    err.SetFillColor(kBlack)
    err.DrawCopy("e2")

    ##############################################
    # More style: horizontal guide lines in the ratio pad
    ##############################################
    xmin = h.GetXaxis().GetXmin()
    xmax = h.GetXaxis().GetXmax()
    l = TLine(xmin, 1, xmax, 1)
    l2 = TLine(xmin, 0.5, xmax, 0.5)
    l3 = TLine(xmin, 1.5, xmax, 1.5)
    l4 = TLine(xmin, 2., xmax, 2.)
    l.SetLineWidth(1)
    l.SetLineStyle(2)
    l2.SetLineStyle(3)
    l3.SetLineStyle(3)
    l4.SetLineStyle(3)

    l.Draw("same")
    l2.Draw("same")
    l3.Draw("same")
    l4.Draw("same")
    h.Draw("p e1 X0 F same")

    # Was ``c.Update`` without parentheses, i.e. the method was never called
    c.Update()
    c.SaveAs(name_vars+"_"+fit+".pdf")

    return





def createRatio(ratio,err_ratio, color, name):
    """Build the two histograms drawn in the ratio panel.

    Args:
        ratio: histogram with the ratio values and errors; it is cloned twice
            and not modified.
        err_ratio: mapping ``bin index (1-based) -> sequence`` whose first
            element is a level of the relative error band; the band
            half-width of that bin is ``abs(level - 1)``.
        color: ROOT colour index used for the ratio line and markers.
        name: x-axis title.

    Returns:
        ``(err, h3)``: ``err`` is the uncertainty band (content 1 in every
        bin of ``err_ratio``) and carries all the axis styling; ``h3`` is the
        styled copy of ``ratio``.
    """
    MyLowerLimit = 0.1
    MyUpperLimit = 2.0

    # Uncertainty band centred on one
    err = ratio.Clone("err")
    for x in range(1,len(err_ratio)+1) :
        err.SetBinContent(x,1)
        err.SetBinError(x,abs(err_ratio[x][0]-1))
    err.SetMarkerStyle(0)

    # The band is drawn first, so it owns the axes of the pad
    x_axis = err.GetXaxis()
    x_axis.SetTitle(name)
    x_axis.SetTitleSize(0.125)
    x_axis.SetTitleOffset(1.1)
    x_axis.SetLabelSize(0.125)
    x_axis.SetLabelOffset(0.04)

    y_axis = err.GetYaxis()
    y_axis.SetRangeUser(MyLowerLimit, MyUpperLimit)
    y_axis.SetNdivisions(504)
    y_axis.SetTitle("Data/SM")
    y_axis.SetTitleSize(0.125)
    y_axis.SetTitleOffset(0.40)
    y_axis.SetLabelSize(0.125)
    y_axis.CenterTitle()

    # Ratio points
    h3 = ratio.Clone("h3")
    h3.SetLineColor(color)
    h3.SetMarkerStyle(20)
    h3.SetMarkerColor(color)
    h3.SetTitle("")
    h3.SetStats(0)
    h3.SetLineWidth(1)
    # Only create the sum-of-weights array when the clone has none yet;
    # calling Sumw2() on a histogram that already stores errors makes ROOT
    # print a warning.
    if not h3.GetSumw2N():
        h3.Sumw2()
    h3.GetYaxis().SetRangeUser(MyLowerLimit, MyUpperLimit)

    return err, h3


def createCanvasPads():
    """Create a 600x600 canvas split into an upper and a lower pad.

    Returns:
        ``(canvas, upper_pad, lower_pad)``; the upper pad covers y from 0.305
        to 1 and the lower pad y from 0.01 to 0.295 of the canvas.
    """
    canvas = TCanvas("Test", 'test', 600, 600)
    canvas.Update()

    # Upper pad: main distribution, ticks on all four sides
    upper = TPad("fPad1", "fPad1", 0., 0.305, .99, 1, 0)
    upper.SetTopMargin(0.1)
    upper.SetBottomMargin(0.02)
    upper.SetLeftMargin(0.1)
    upper.SetRightMargin(0.1)
    upper.SetFillColor(0)
    upper.SetTickx()
    upper.SetTicky()
    upper.Draw()

    # Lower pad: large bottom margin leaves room for the x-axis title
    lower = TPad("fPad2", "fPad2", 0., 0.01, .99, 0.295, 0)
    lower.SetTopMargin(0.05)
    lower.SetBottomMargin(0.3)
    lower.SetLeftMargin(0.1)
    lower.SetRightMargin(0.1)
    lower.SetFillColor(0)
    lower.Draw()
    lower.SetFillColor(0)

    return canvas, upper, lower
    


def custom_text(x, y, text, size, color):
    """Draw ``text`` at NDC position (``x``, ``y``) on the current pad.

    The text uses font 42 with the given text ``size`` and ``color``.
    """
    latex = TLatex()
    latex.SetNDC()
    latex.SetTextFont(42)
    latex.SetTextColor(color)
    latex.SetTextSize(size)
    latex.DrawLatex(x, y, text)



def binATLAS(histo):
    """Normalise ``histo`` in place to its narrowest bin width.

    Calls ``histo.Scale(min_bin, "width")``, where ``min_bin`` is the
    smallest width among bins ``1..GetNbinsX()``.

    Args:
        histo: object providing ``GetNbinsX()``, ``GetBinWidth(j)`` and
            ``Scale(factor, option)``.

    Returns:
        The smallest bin width that was used as scale factor.

    Raises:
        ValueError: if the histogram reports no bins.
    """
    n_bins = histo.GetNbinsX()
    if n_bins < 1:
        raise ValueError("binATLAS: histogram has no bins to normalise")

    # Take the true minimum; the previous fixed starting value of 10000 gave
    # a wrong result for histograms whose bins are all wider than that.
    min_bin = min(histo.GetBinWidth(j) for j in range(1, n_bins + 1))

    histo.Scale(min_bin, "width")

    return min_bin




    

# Produce the plot only when this file is run as a script, not on import.
if __name__ == "__main__":
    main()
