import ROOT as R
from array import array

def alias(t):
    """Register the standard truth-level aliases on a ROOT tree.

    All aliases read from ``McEventCollection_p5_BeamTruthEvent``:

    * particle momentum components ``px``, ``py``, ``pz``, transverse and
      total momentum ``pt``/``p``, mass ``m`` and energy ``e``;
    * polar angle ``th = asin(pt/p)`` and azimuth ``ph = acos(px/pt)``;
    * event weight ``wt``, scaled by ``1/(137130000/80e12)`` (noted in the
      original as a normalisation to 1 fb);
    * vertex coordinates ``x``, ``y``, ``z`` and transverse radius ``r``.

    Parameters
    ----------
    t : object
        Anything providing ``SetAlias(name, expression)`` -- in this script a
        ``ROOT.TTree`` or ``ROOT.TChain``.
    """
    definitions = (
        ("px", "McEventCollection_p5_BeamTruthEvent.m_genParticles.m_px"),
        ("py", "McEventCollection_p5_BeamTruthEvent.m_genParticles.m_py"),
        ("pz", "McEventCollection_p5_BeamTruthEvent.m_genParticles.m_pz"),
        ("pt", "sqrt(px*px + py*py)"),
        ("p", "sqrt(pz*pz + pt*pt)"),
        ("m", "McEventCollection_p5_BeamTruthEvent.m_genParticles.m_m"),
        ("e", "sqrt(px*px + py*py + pz*pz + m*m)"),
        ("th", "asin(pt/p)"),
        ("ph", "acos(px/pt)"),
        # Norm to 1 fb
        ("wt", "McEventCollection_p5_BeamTruthEvent.m_genEvents.m_weights * 1/(137130000/80e12)"),
        ("x", "McEventCollection_p5_BeamTruthEvent.m_genVertices.m_x"),
        ("y", "McEventCollection_p5_BeamTruthEvent.m_genVertices.m_y"),
        ("z", "McEventCollection_p5_BeamTruthEvent.m_genVertices.m_z"),
        ("r", "sqrt(x*x + y*y)"),
    )
    for alias_name, expression in definitions:
        t.SetAlias(alias_name, expression)

def set_legend(args):
    """Return a borderless, unfilled ``ROOT.TLegend``.

    ``args`` is either a placement keyword -- ``"TR"`` for the box
    ``(0.68, 0.68, 0.86, 0.86)`` or ``"TL"`` for ``(0.15, 0.68, 0.35, 0.88)``
    -- or any 4-element sequence ``(x1, y1, x2, y2)`` that is forwarded
    unchanged to the ``TLegend`` constructor.
    """
    if args in ("TR", "TL"):
        corners = {
            "TR": (0.68, 0.68, 0.86, 0.86),
            "TL": (0.15, 0.68, 0.35, 0.88),
        }[args]
    else:
        corners = args

    x1, y1, x2, y2 = corners
    legend = R.TLegend(x1, y1, x2, y2)

    # No frame and a transparent background: only the entries are drawn.
    legend.SetBorderSize(0)
    legend.SetFillColor(0)
    legend.SetFillStyle(0)
    return legend

# Variable-width energy bin edges used by main() when ``ebins`` is None.
_VARIABLE_EBIN_EDGES = (0, 50, 100, 150, 200, 250, 300, 350, 400, 450, 500,
                        600, 700, 800, 900, 1000, 1250, 1500, 1750, 2000,
                        2250, 2500, 3000)


def _selection(theta, radii):
    """Return the ``" * (...)"`` weight multiplier for the optional cuts, or ``""``."""
    terms = []
    if theta:
        terms.append(f"(th > {theta[0]} && th < {theta[1]})")
    if radii:
        terms.append(f"(r > {radii[0]} && r < {radii[1]})")
    return f" * ({' && '.join(terms)})" if terms else ""


def _label(hist, name):
    """Set both the name and the title of ``hist`` to ``name`` and return it."""
    hist.SetName(name)
    hist.SetTitle(name)
    return hist


def _fetch(name):
    """Return histogram ``name`` from ``R.gDirectory`` (labelled), failing loudly if absent."""
    hist = R.gDirectory.Get(name)
    if not hist:
        # e.g. the tree was empty, so TTree::Draw never booked the histogram.
        raise RuntimeError(f"TTree::Draw did not create histogram {name!r}")
    return _label(hist, name)


def main(t, charge, ebins = (100, 106, 4500e3), thbins = (140,0,0.14), theta = (), radii = ()):
    """Fill energy, theta and theta-vs-energy histograms from a truth tree.

    Parameters
    ----------
    t : ROOT.TTree or ROOT.TChain
        Tree holding the ``McEventCollection_p5_BeamTruthEvent`` branch.
    charge : str
        Suffix for histogram names, e.g. ``"minus"`` or ``"plus"``.
    ebins : (nbins, low, high) or None
        Energy binning. ``None`` selects the variable-width edges in
        ``_VARIABLE_EBIN_EDGES``.
    thbins : (nbins, low, high)
        Polar-angle (``th``) binning.
    theta : (low, high), optional
        If non-empty, keep only entries with ``low < th < high``.
    radii : (low, high), optional
        If non-empty, keep only entries with ``low < r < high``.

    Returns
    -------
    tuple
        ``(energy_<charge>, theta_<charge>, thetaenergy_<charge>)``. The third
        item is a 2D histogram with energy on x and ``th`` on y. All three are
        filled with weight ``wt`` times the optional selection.

    Raises
    ------
    RuntimeError
        If ``TTree::Draw`` did not produce an expected histogram.
    """
    # Same alias table as alias(); the vertex aliases are only used by the radii cut.
    alias(t)

    cut = _selection(theta, radii)
    print(cut)
    weight = f"wt {cut}"

    ename = f"energy_{charge}"
    thname = f"theta_{charge}"
    ethname = f"thetaenergy_{charge}"

    # Double_t edges: TH2F has no (Float_t* x-edges, ny, ylow, yup) constructor.
    edges = array("d", _VARIABLE_EBIN_EDGES) if ebins is None else None

    if edges is not None:
        he = _label(R.TH1F(ename, ename, len(edges) - 1, edges), ename)
        # ">> +name" fills the pre-booked variable-width histogram instead of rebooking it.
        t.Draw(f"e >> +{ename}", weight, "hist goff")
    else:
        t.Draw(f"e >> {ename}({ebins[0]}, {ebins[1]}, {ebins[2]})", weight, "hist goff")
        he = _fetch(ename)

    print(he.GetEntries(), he.GetSumOfWeights())

    t.Draw(f"th >> {thname}({thbins[0]}, {thbins[1]}, {thbins[2]})", weight, "hist goff")
    hth = _fetch(thname)

    if edges is not None:
        heth = _label(
            R.TH2F(ethname, ethname, len(edges) - 1, edges, thbins[0], thbins[1], thbins[2]),
            ethname,
        )
        # "y:x" order: th on the y axis, energy on the x axis.
        t.Draw(f"th:e >> +{ethname}", weight, "hist goff")
    else:
        t.Draw(
            f"th:e >> {ethname}({ebins[0]}, {ebins[1]}, {ebins[2]}, {thbins[0]}, {thbins[1]}, {thbins[2]})",
            weight,
            "hist goff",
        )
        heth = _fetch(ethname)

    return he, hth, heth

def create(oname = "fluka_spectra.root", ebins = (100, 106, 4500e3), thbins = (140,0,0.14), theta = (), radii = (),
           neg_files = "/eos/experiment/faser/sim/mc22/fluka/210001/rdo/s0009/FaserMC-MC22_Fluka*210001*-RDO.root",
           pos_files = "/eos/experiment/faser/sim/mc22/fluka/210002/rdo/s0009/FaserMC-MC22_Fluka*-210002*-RDO.root"):
    """Build the spectra for both charges and write them to ``oname``.

    ``neg_files`` / ``pos_files`` are ``TChain.Add`` patterns for the
    ``"minus"`` and ``"plus"`` samples. Earlier versions read single
    HITS/EVNT files (e.g. under /eos/experiment/faser/gen/FlukaMDC/EVNT/).
    ``ebins``, ``thbins``, ``theta`` and ``radii`` are forwarded to
    ``main``. The output file is recreated and receives the "plus"
    histograms, then the "minus" ones.

    Raises
    ------
    FileNotFoundError
        If a wildcard input pattern matches no files.
    OSError
        If the output file cannot be opened.
    """
    def _chain(pattern):
        chain = R.TChain("CollectionTree")
        # For wildcard patterns TChain.Add returns how many files matched.
        if chain.Add(pattern) == 0:
            raise FileNotFoundError(f"no input files match {pattern!r}")
        return chain

    # Keep the chains bound until the histograms are written.
    tneg = _chain(neg_files)
    histsneg = main(tneg, "minus", ebins = ebins, thbins = thbins, theta = theta, radii = radii)

    tpos = _chain(pos_files)
    histspos = main(tpos, "plus", ebins = ebins, thbins = thbins, theta = theta, radii = radii)

    fout = R.TFile.Open(oname, "RECREATE")
    if not fout or fout.IsZombie():
        raise OSError(f"cannot open output file {oname!r}")
    try:
        for h in tuple(histspos) + tuple(histsneg):
            h.Write()
    finally:
        fout.Close()

def check(name, charge, orig_file = "fluka_spectra.root", new_file = "FaserMC-PG_logE-999999-00000-HITS.root"):
    """Overlay a stored spectrum with the same quantity drawn from a new file.

    The histogram ``<name>_<charge>`` read from ``orig_file`` is drawn in
    black. The same quantity is filled from ``new_file``'s
    ``CollectionTree`` with its raw ``m_weights`` and drawn in red. Both are
    drawn with ``DrawNormalized`` on a log-y canvas, which is saved as
    ``fluka_check_<name>_<charge>.eps``.

    Parameters
    ----------
    name : str
        ``"energy"`` or ``"theta"``.
    charge : str
        Histogram-name suffix, e.g. ``"minus"``.

    Raises
    ------
    ValueError
        If ``name`` is not supported.
    OSError
        If an input file cannot be opened.
    KeyError
        If the stored histogram is missing from ``orig_file``.
    RuntimeError
        If drawing from the new tree produced no histogram.
    """
    hname = f"{name}_{charge}"
    draw_exprs = {
        "energy": f"e >> energy_{charge}(100, 0, 4500e3)",
        "theta": f"th >> theta_{charge}(140,0,0.14)",
    }
    if name not in draw_exprs:
        raise ValueError(f"unsupported quantity {name!r}; expected one of {sorted(draw_exprs)}")

    c = R.TCanvas()
    R.gPad.SetLogy(True)

    forig = R.TFile.Open(orig_file)
    if not forig or forig.IsZombie():
        raise OSError(f"cannot open {orig_file!r}")
    try:
        horig = forig.Get(hname)
        if not horig:
            raise KeyError(f"{hname!r} not found in {orig_file!r}")
        horig.SetLineColor(1)
        horig.DrawNormalized("hist")
        # Detach so the histogram survives closing the file.
        horig.SetDirectory(0)
    finally:
        forig.Close()

    fnew = R.TFile.Open(new_file)
    if not fnew or fnew.IsZombie():
        raise OSError(f"cannot open {new_file!r}")
    try:
        tnew = fnew.Get("CollectionTree")

        # Same kinematics as alias(), but with the un-normalised event weight.
        aliases = (
            ("px", "McEventCollection_p5_BeamTruthEvent.m_genParticles.m_px"),
            ("py", "McEventCollection_p5_BeamTruthEvent.m_genParticles.m_py"),
            ("pz", "McEventCollection_p5_BeamTruthEvent.m_genParticles.m_pz"),
            ("pt", "sqrt(px*px + py*py)"),
            ("p", "sqrt(pz*pz + pt*pt)"),
            ("m", "McEventCollection_p5_BeamTruthEvent.m_genParticles.m_m"),
            ("e", "sqrt(px*px + py*py + pz*pz + m*m)"),
            ("th", "asin(pt/p)"),
            ("ph", "acos(px/pt)"),
            ("wt", "McEventCollection_p5_BeamTruthEvent.m_genEvents.m_weights"),
        )
        for alias_name, expression in aliases:
            tnew.SetAlias(alias_name, expression)

        tnew.Draw(draw_exprs[name], "wt", "hist goff")

        h2 = R.gDirectory.Get(hname)
        if not h2:
            raise RuntimeError(f"TTree::Draw did not create histogram {hname!r}")
        h2.SetName(hname)
        h2.SetTitle(hname)
        h2.SetLineColor(2)
        # DrawNormalized draws a detached clone, so closing fnew afterwards is safe.
        h2.DrawNormalized("histsame")

        c.Print(f"fluka_check_{name}_{charge}.eps")
    finally:
        fnew.Close()

def comp(name, charge, fnames = None):
    """Overlay ``<name>_<charge>`` from several files, each drawn normalised.

    Parameters
    ----------
    name : str
        Histogram base name, e.g. ``"energy"`` or ``"theta"``.
    charge : str
        Histogram-name suffix, e.g. ``"minus"``.
    fnames : mapping of legend label -> ROOT file name, optional
        Files to overlay, drawn in iteration order with line colours 1, 2, ...
        Defaults to the truncated-radius spectra files.

    The canvas is saved as ``fluka_comp_<name>_<charge>.eps``.

    Raises
    ------
    OSError
        If a file cannot be opened.
    KeyError
        If a file lacks the requested histogram.
    """
    if fnames is None:
        fnames = {
            "<100mm" : "fluka_spectra_trunc_100.root",
            "<200mm" : "fluka_spectra_trunc_200.root",
            "<300mm" : "fluka_spectra_trunc_300.root",
            "<400mm" : "fluka_spectra_trunc_400.root",
            "<500mm" : "fluka_spectra_trunc_500.root",
            "None" : "fluka_spectra_trunc.root"
            }

    hname = f"{name}_{charge}"

    c = R.TCanvas()
    # Hold Python references so the histograms outlive the loop (legend entries point at them).
    c._objs = []
    leg = set_legend("TR")

    for i, (label, fn) in enumerate(fnames.items()):
        f = R.TFile.Open(fn)
        if not f or f.IsZombie():
            raise OSError(f"cannot open {fn!r}")
        try:
            h = f.Get(hname)
            if not h:
                raise KeyError(f"{hname!r} not found in {fn!r}")
            # Detach before closing so the histogram is not deleted with the file.
            h.SetDirectory(0)
        finally:
            f.Close()

        c._objs.append(h)
        h.SetLineColor(i + 1)
        h.DrawNormalized("histsame" if i else "hist")
        leg.AddEntry(h, label, "l")

    leg.Draw()
    c.Print(f"fluka_comp_{name}_{charge}.eps")

if __name__ == "__main__":
    # Batch mode: no graphics windows are opened.
    R.gROOT.SetBatch(True)
    # R.gStyle.SetOptStat(0)

    # Active configuration: vertex radius cut 90 < r < 250, no theta cut.
    create(
        "fluka_geom_large_radius_r9_r25.root",
        ebins=(50, 106, 4500e3),
        thbins=(60, 0, 0.06),
        radii=(90, 250),
    )

    # ------------------------------------------------------------------
    # Earlier configurations, kept for reference.
    # ------------------------------------------------------------------
    # create("fluka_geom_with_radius.root", ebins = (50, 106, 100e3), thbins = (15,0.015,0.060), theta = (0.015, 0.06), radii = (150, 300))
    #
    # create("fluka_spectra_5mrad.root", ebins = (50, 106, 4500e3) , thbins = (50,0,0.005), theta = (0, 0.005) )
    # create("fluka_spectra_allmrad.root", ebins = (50, 106, 4500e3) , thbins = (100,0,0.02), theta = (0, 0.02) )
    #
    # create("fluka_spectra_all.root")
    # create("fluka_spectra_trunc.root", thbins = (100,0,0.005), theta = (0, 0.005) )
    #
    # create("fluka_spectra_geom.root", ebins = (50, 106, 100e3), thbins = (30,0.030,0.060), theta = (0.03, 0.06)) #, radii = (100, 130))
    # create("fluka_spectra_geom_15mrad.root", ebins = (50, 106, 100e3), thbins = (15,0.015,0.030), theta = (0.015, 0.03)) #, radii = (100, 130))
    # create("fluka_spectra_geom.root", ebins = (50, 106, 100e3), thbins = (30,0.030,0.060), theta = (0.03, 0.06), radii = (100, 130))
    # create("fluka_spectra_geom.root", ebins = (50, 106, 100e3), thbins = (15, 0.015, 0.030), theta = (0.03, 0.06), radii = (100, 130))
    #
    # for r in [100, 200, 300, 400, 500]:
    #     create(f"fluka_spectra_trunc_{r}.root", thbins = (100,0,0.005), theta = (0, 0.005), radii = (0, r))
    #
    # r = 100
    # create(f"fluka_spectra_trunc_{r}.root", ebins = None, thbins = (25,0,0.005), theta = (0, 0.005), radii = (0, r))
    #
    # Validation / comparison plots:
    # check("energy", "minus")
    # check("theta", "minus")
    #
    # comp("energy", "minus")
    # comp("energy", "plus")
    # comp("theta", "minus")
    # comp("theta", "plus")


# all neg 411697916.3642578 in 1/fb
# all pos 316713577.5716553 in 1/fb

# trunc neg 277698985.6767578 in 1/fb
# trunc pos 204611576.10290527 in 1/fb

# geom neg: 1164368.7707519531 in 30/fb
# geom pos:  303460.0591278076 in 30/fb




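# Example printout from main(): the selection-cut string (e.g. " * ((r > 90 && r < 200))")
# and the energy histogram's entry count and sum of weights.
# 54536.0 15060706.792869568
# * ((r > 90 && r < 200))
# 16123.0 8907735.584281921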
