import sys
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from ROOT import TFile, TH2F

# Module-level placeholders for the six muon-resolution parameter tables.
# Nothing in this file assigns to them through a ``global`` statement, so
# they stay empty; code that needs the tables obtains them from mu_res().
p1_ID_MC, p2_ID_MC, p2_ID_TAN2_MC = [], [], []
p0_MS_MC, p1_MS_MC, p2_MS_MC = [], [], []

def main():
    """Produce the muon efficiency maps and the combined-resolution map.

    Calls mu_eff() for the tight, high-pt and loose efficiency files (writing
    ``mu_eff_tight.png``, ``mu_eff_high-pt.png`` and ``mu_eff_loose.png``)
    and then writes ``muon_res.png`` with the output of comb_res() scanned
    over pT and eta at phi = 0.
    """
    calib_area = ('/cvmfs/atlas.cern.ch/repo/sw/database/GroupData/'
                  'UpgradePerformanceFunctions/CalibArea-00-01/')
    for root_name, label in (('Muon_Eff_Tight.root', 'tight'),
                             ('Muon_Eff_HighPt.root', 'high-pt'),
                             ('Muon_Eff_Loose.root', 'loose')):
        mu_eff(calib_area + root_name, label)

    _plot_resolution_map()


def _plot_resolution_map():
    """Scan comb_res() over a (pT, eta) grid and save it as muon_res.png."""
    pt_vals = np.linspace(5.0, 1000, 500)
    eta_vals = np.linspace(-4.0, 4.0, 500, endpoint=False)
    # pcolormesh(x, y, C) wants C shaped (len(y), len(x)): rows follow eta,
    # columns follow pT.
    sigma = np.zeros((len(eta_vals), len(pt_vals)))

    for i_eta, eta in enumerate(eta_vals):
        for i_pt, pt in enumerate(pt_vals):
            sigma[i_eta, i_pt] = comb_res(pt, eta, 0.0)

    fig, ax = plt.subplots()
    pcm = ax.pcolormesh(pt_vals, eta_vals, sigma, norm=mpl.colors.LogNorm())
    cb = fig.colorbar(pcm, ax=ax)
    cb.set_label(r'Muon $p_{T}$ resolution [GeV]')
    ax.set_xlabel(r'Muon $p_{T}$ [GeV]')
    # The scan covers negative and positive eta, so label the signed quantity.
    ax.set_ylabel(r'Muon $\eta$')
    # No ``quality`` argument: it only ever applied to JPEG output and is
    # rejected by current Matplotlib releases when saving a PNG.
    fig.savefig('muon_res.png', dpi=500)
    plt.close(fig)

def mu_eff(m_file, name):
    """Plot the muon efficiency map stored in a ROOT file.

    Reads the ``MC_Eff_All`` histogram from ``m_file``, samples it on a
    1000 x 1000 grid of pT (10 to 1000) and eta (0 to 4), and saves the
    result as ``mu_eff_<name>.png``.  The histogram is looked up with eta on
    its first axis and pT on its second.  For eta above 2.7 the histogram is
    not consulted: the value is 0.95 where pT > 10 and 0 otherwise.

    Args:
        m_file: Path of the ROOT file holding ``MC_Eff_All``.
        name: Label inserted into the output file name.

    Raises:
        IOError: If the file cannot be read or has no ``MC_Eff_All`` object.
    """
    tfile = TFile(m_file)
    if tfile.IsZombie():
        raise IOError('Cannot open efficiency file {}'.format(m_file))

    pt_vals = np.linspace(10, 1000, 1000)
    eta_vals = np.linspace(0.0, 4.0, 1000)
    # pcolormesh(x, y, C) wants C shaped (len(y), len(x)): rows follow eta,
    # columns follow pT.
    eff = np.zeros((len(eta_vals), len(pt_vals)))
    # Same for every row beyond the histogram's eta reach, so build it once.
    forward_row = np.where(pt_vals > 10, 0.95, 0.0)

    try:
        mc_eff = tfile.Get('MC_Eff_All')
        if not mc_eff:
            raise IOError('No MC_Eff_All histogram in {}'.format(m_file))

        for i_eta, eta in enumerate(eta_vals):
            if np.abs(eta) > 2.7:
                eff[i_eta, :] = forward_row
                continue
            for i_pt, pt in enumerate(pt_vals):
                eff[i_eta, i_pt] = mc_eff.GetBinContent(mc_eff.FindBin(eta, pt))
    finally:
        # Every value needed has been copied into ``eff`` by now.
        tfile.Close()

    fig, ax = plt.subplots()
    pcm = ax.pcolormesh(pt_vals, eta_vals, eff, vmin=0.5, vmax=1.0)
    cb = fig.colorbar(pcm, ax=ax)
    cb.set_label(r'Efficiency')
    ax.set_xlabel(r'Muon $p_{T}$ [GeV]')
    ax.set_ylabel(r'Muon $\left|\eta\right|$')
    # No ``quality`` argument: it only ever applied to JPEG output and is
    # rejected by current Matplotlib releases when saving a PNG.
    fig.savefig('mu_eff_{}.png'.format(name), dpi=500)
    # This function is called once per working point; release each figure.
    plt.close(fig)

def comb_res(pt, eta, phi):
    """Combine the itk_res() and ms_res() resolutions at (pt, eta, phi).

    Returns the itk_res() value alone when |eta| > 2.7, the ms_res() value
    alone when the itk_res() value is not positive, and otherwise
    ``a * b / sqrt(a**2 + b**2)`` of the two values.
    """
    id_sigma = itk_res(pt, eta, phi)
    ms_sigma = ms_res(pt, eta, phi)

    if np.abs(eta) > 2.7:
        return id_sigma
    if not id_sigma > 0.0:
        return ms_sigma

    quad_sum = id_sigma * id_sigma + ms_sigma * ms_sigma
    return id_sigma * ms_sigma / np.sqrt(quad_sum)

def itk_res(pt, eta, phi):
    """Return ``pt * sqrt(p1**2 + (p2 * pt)**2)`` for the region of (eta, phi).

    ``p1`` and ``p2`` are taken from the first and second tables returned by
    mu_res() (p1_ID_MC and p2_ID_MC) at the index given by get_region().
    Returns 1E9 when ``pt`` is zero and 0.0 when get_region() yields -99.
    """
    if pt == 0.0:
        return 1E9

    tables = mu_res()
    region = get_region(eta, phi)
    if region == -99:
        return 0.0

    const_term = tables[0][region]
    linear_term = tables[1][region]
    return pt * np.sqrt(const_term ** 2 + (linear_term * pt) ** 2)

def ms_res(pt, eta, phi):
    """Return ``pt * sqrt((p0/pt)**2 + p1**2 + (p2*pt)**2)`` for (eta, phi).

    ``p0``, ``p1`` and ``p2`` are taken from the fourth, fifth and sixth
    tables returned by mu_res() (p0_MS_MC, p1_MS_MC and p2_MS_MC) at the
    index given by get_region().  Returns 1E9 when ``pt`` is zero and 0.0
    when get_region() yields -99.
    """
    if pt == 0.0:
        return 1E9

    tables = mu_res()
    region = get_region(eta, phi)
    if region == -99:
        return 0.0

    inverse_term = tables[3][region]
    const_term = tables[4][region]
    linear_term = tables[5][region]
    relative = np.sqrt(
        (inverse_term / pt) ** 2 + const_term ** 2 + (linear_term * pt) ** 2
    )
    return pt * relative

def mu_res():
    """Load the muon-resolution parameter tables from ``mu_res.dat``.

    The first two lines of the file are skipped as a header.  Each following
    non-blank line must hold at least six whitespace-separated numbers, read
    in column order as p1_ID_MC, p2_ID_MC, p2_ID_TAN2_MC, p0_MS_MC, p1_MS_MC
    and p2_MS_MC.  Blank lines are ignored.

    Returns:
        A 6-tuple of lists of floats, one list per column, in the order above.

    Raises:
        IndexError: If a data line has fewer than six fields.
        ValueError: If a field cannot be converted to float.
    """
    header_lines = 2
    n_columns = 6
    columns = tuple([] for _ in range(n_columns))

    with open('mu_res.dat') as data_file:
        for line_no, raw_line in enumerate(data_file):
            if line_no < header_lines:
                continue
            # Split on any run of whitespace so the column spacing is free.
            fields = raw_line.split()
            if not fields:
                continue
            # Convert the whole row first so a bad line adds nothing.
            values = [float(fields[i]) for i in range(n_columns)]
            for column, value in zip(columns, values):
                column.append(value)

    return columns

def get_region(eta, phi):
    """Return the index of the eta bin that contains ``eta``.

    Bins are the half-open intervals ``[edge[i], edge[i+1])`` between the 33
    edges listed below, giving indices 0 to 31.

    Args:
        eta: Value to locate among the bin edges.
        phi: Must lie strictly between -4.0 and 4.0 for a region to be found.

    Returns:
        The bin index, or -99 when ``phi`` fails the range check or ``eta``
        lies outside ``[-4.0, 4.0)``.
    """
    # NOTE(review): this guard bounds phi, not eta, by +/-4.0 -- confirm that
    # this is the intended acceptance check.
    if not (-4.0 < phi < 4.0):
        return -99

    eta_edges = (-4.0, -3.8, -3.6, -3.4, -3.2,
                 -3.0, -2.7, -2.5, -2.3, -2.0,
                 -1.7, -1.5, -1.25, -1.05, -0.8,
                 -0.4, 0.0, 0.4, 0.8, 1.05,
                 1.25, 1.5, 1.7, 2.0, 2.3,
                 2.5, 2.7, 3.0, 3.2, 3.4,
                 3.6, 3.8, 4.0)
    # Walk adjacent edge pairs so the last edge is only ever an upper bound;
    # an eta at or beyond it falls through to -99 instead of indexing past
    # the end of the edge list.
    for index, (low, high) in enumerate(zip(eta_edges, eta_edges[1:])):
        if low <= eta < high:
            return index
    return -99

# Run main() only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()    