import matplotlib.pyplot as plt
import numpy as np
from ROOT import *

# Key of the two-argument fit function (evaluated as Eval(pt, |eta|)) that
# ftag_plot reads from the ROOT file.
# NOTE(review): the suffix presumably encodes <mu>=200 pileup and the MV2c10
# tagger at its 70% working point -- confirm against whoever produced the file.
fitfun = 'fitfun_mu200_mv2c10_70_2'

def main():
    """Script entry point: render the default efficiency map to 'ftag_test.png'."""
    output_stem = 'ftag_test'
    ftag_plot(output_stem)

def ftag_plot(name):
    """Plot the b-jet tagging efficiency over jet pT and |eta| and save it as PNG.

    For every point of a 1000 x 1000 grid (pT in [10, 1000] GeV, |eta| in
    [0, 4.5]) the efficiency is the 2D fit function ``fitfun`` read from
    ``flavor_tags_v2.0.root``, evaluated at (pT, |eta|), times the b-jet
    high-pT correction ``high_pt_corr(pT, 'B')``.

    Args:
        name: Output file stem; the figure is written to ``'<name>.png'``.

    Raises:
        IOError: if the ROOT file cannot be opened.
        KeyError: if ``fitfun`` or ``'pt_limits'`` is missing from the file.
    """
    x = np.linspace(10, 1000, 1000)  # jet pT [GeV]
    y = np.linspace(0.0, 4.5, 1000)  # jet |eta|
    # pcolormesh(x, y, z) expects z with shape (len(y), len(x)): rows follow y,
    # columns follow x. The old (len(x), len(y)) allocation only worked because
    # both axes happened to have the same length.
    z = np.zeros((len(y), len(x)))

    root_path = 'flavor_tags_v2.0.root'
    tfile = TFile(root_path)
    if not tfile or tfile.IsZombie():
        raise IOError('could not open ROOT file {}'.format(root_path))
    try:
        ftag_func = tfile.Get(fitfun)
        pt_limit = tfile.Get('pt_limits')
        # Get() hands back a null (falsy) proxy for an absent key; fail here
        # with a clear message instead of on the first Eval() call.
        for key, obj in ((fitfun, ftag_func), ('pt_limits', pt_limit)):
            if not obj:
                raise KeyError('{!r} not found in {}'.format(key, root_path))

        for i_pt, pt in enumerate(x):
            # The correction depends only on pT, so compute it once per column.
            corr = high_pt_corr(pt, 'B')
            for i_eta, eta in enumerate(y):
                z[i_eta, i_pt] = ftag_func.Eval(pt, eta) * corr

        # NOTE(review): diagnostic output kept from the original; what the
        # argument 2 represents for 'pt_limits' is not established here.
        print(pt_limit.Eval(2))
    finally:
        # All ROOT objects have been used by this point; release the file.
        tfile.Close()

    fig = plt.figure()
    ax = plt.subplot(111)
    pcm = ax.pcolormesh(x, y, z, vmin=0.0, vmax=1.0)
    fig.colorbar(pcm, ax=ax)
    ax.set_xlabel(r'Jet $p_{T}$ [GeV]')
    ax.set_ylabel(r'Jet $\left|\eta\right|$')
    # PNG is lossless: the JPEG-only 'quality' option is dropped because it is
    # rejected by current matplotlib.
    fig.savefig('{}.png'.format(name), dpi=500)
    plt.close(fig)  # don't accumulate open figures when called repeatedly

def high_pt_corr(pt, flavour):
    """Return the linear high-pT correction factor for the tagging efficiency.

    The factor is ``1 + p1 * (pt - 300)``, where ``p1`` is a flavour-dependent
    slope. ``pt`` is capped at 1000, so the correction is frozen above that
    value. In this file ``pt`` is the jet pT in GeV (see the plot axis in
    ftag_plot).

    Args:
        pt: Jet pT, either a scalar or a NumPy array (clamped element-wise).
        flavour: ``'B'`` or ``'C'`` select their own slopes. Any other value
            uses the remaining slope, which the original comment labelled
            "L,P".

    Returns:
        The correction factor: a NumPy scalar for scalar ``pt``, or an array
        of the same shape as ``pt``.
    """
    # Per-flavour slopes; every other flavour falls back to the L/P slope.
    slopes = {'B': -6.17434e-04, 'C': -7.20231e-04}
    p1 = slopes.get(flavour, 2.78547e-04)
    # Cap pT at 1000. np.minimum does this element-wise, so arrays work too.
    capped_pt = np.minimum(pt, 1000)
    return 1 + p1 * (capped_pt - 300)


# Produce the default plot only when run as a script, not when imported.
if __name__=='__main__':
    main()