import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy

# Column layout of the input text files, in file order. extract_data converts
# the first four of these to int and the remaining ones to float.
default_cols = [
    "run", "event", "type", "gen",
    "E", "w", "x", "y", "cosX", "cosY",
]
# Muon mass: 105.66 divided by 1000 (presumably MeV -> GeV, matching the
# "[GeV]" axis labels used for E and Pz).
mmu = 105.66 / 1e3

def extract_data(fname, columns = default_cols, fiducial = False, x = 100, y = 100, n = None):
    """Read a whitespace-separated particle list into a list of per-line dicts.

    Each non-blank line of ``fname`` is split on whitespace and paired, in
    order, with ``columns``; extra trailing fields are ignored. The first four
    columns are converted to ``int`` and the rest to ``float``. A derived
    ``"Pz"`` value is added to every kept entry.

    Parameters
    ----------
    fname : str
        Path of the text file to read.
    columns : sequence of str
        Column names in file order. Must contain "E" and "cosY" (used for
        "Pz"), plus "x" and "y" when ``fiducial`` is true.
    fiducial : bool
        If true, keep only entries with ``abs(entry["x"]) < x`` and
        ``abs(entry["y"]) < y``.
    x, y : float
        Half-widths of the fiducial window, in the units of the file's
        "x"/"y" columns.
    n : int or None
        Maximum number of entries to return, counted after the fiducial cut.
        ``None`` reads the whole file.

    Returns
    -------
    list of dict
        One dict per kept line, mapping each column name (and "Pz") to its value.

    Raises
    ------
    ValueError
        If a line has fewer fields than ``columns``, or a field does not parse
        as a number.
    """
    # Per-column converter, built once rather than per line. Conversion is
    # positional: the first four columns are integers, the rest floats.
    converters = [(c, int if idx < 4 else float) for idx, c in enumerate(columns)]
    ncols = len(converters)

    data = []
    with open(fname) as f:
        for lineno, line in enumerate(f, start = 1):
            # The entry limit is tracked with len(data) so no loop variable can
            # clobber it; exactly n entries are returned when n is given.
            if n is not None and len(data) >= n:
                break

            fields = line.split()
            if not fields:
                continue  # tolerate blank lines (e.g. a trailing empty line)
            if len(fields) < ncols:
                raise ValueError(
                    f"{fname}:{lineno}: expected {ncols} columns, found {len(fields)}"
                )

            entry = {c: conv(v) for (c, conv), v in zip(converters, fields)}

            # Apply the fiducial cut first so Pz is only computed for kept entries.
            if fiducial and not (abs(entry["x"]) < x and abs(entry["y"]) < y):
                continue

            # Pz = |p| * cos|angle(cosY)| with |p| = sqrt(E^2 - mmu^2): "E" is
            # treated as total energy in the same units as mmu. E < mmu gives NaN.
            # NOTE(review): cos|arcsin(c)| == sqrt(1 - c^2), so this projection uses
            # cosY only; the full z-projection would be sqrt(1 - cosX^2 - cosY^2).
            # Confirm this approximation is intended.
            entry["Pz"] = np.sqrt(entry["E"]**2 - mmu**2) * np.cos(np.abs(angle(entry["cosY"])))

            data.append(entry)

    return data

def flatten(data, key):
    """Return a new list holding ``record[key]`` for each record in ``data``.

    Despite the name, this does not flatten nested lists: it extracts one
    field (column) from a list of dicts, preserving order. A missing key
    raises KeyError.
    """
    values = []
    for record in data:
        values.append(record[key])
    return values

def angle(cosTheta):
    """Return pi/2 - arccos(cosTheta), i.e. arcsin(cosTheta), in radians.

    ``cosTheta`` is a direction cosine with respect to the x or y axis. The
    result is the signed angle between the direction and the plane
    perpendicular to that axis, in [-pi/2, pi/2]. It equals the angle from
    the z axis only when the other transverse direction cosine is zero.
    Works element-wise on numpy arrays.
    """
    theta_axis = np.arccos(cosTheta)  # angle from the x/y axis, in [0, pi]
    return np.pi / 2 - theta_axis

def plot(data = None, data_pos = None, data_neg = None, key = "E", 
         xlabel = "x", ylabel = "y", name = "E.png", bins = None, 
         logx = False, logy = True, density = True, outdir = "Fluka"):
    """Overlay weighted step histograms of one field and save them as a PNG.

    Up to three samples are drawn:
    - ``data`` in black, labelled "all";
    - ``data_pos`` in blue, labelled "pos";
    - ``data_neg`` in red, labelled "neg".

    Each sample is a list of dicts (as returned by ``extract_data``). ``key``
    selects the histogrammed field, and each entry's ``"w"`` field is its
    weight. Empty or ``None`` samples are skipped.

    Parameters
    ----------
    key : str
        Field to histogram.
    xlabel, ylabel : str
        Axis labels.
    name : str
        Output file name. ``".png"`` is appended unless already present.
    bins : int, sequence or None
        Passed straight to ``plt.hist``.
    logx, logy : bool
        Use a log scale on the x / y axis.
    density : bool
        Passed to ``plt.hist``; normalises each histogram to unit area.
    outdir : str
        Output directory, created if it does not exist. Defaults to "Fluka".
        An empty string means the current directory.

    The figure is closed after saving, so repeated calls do not accumulate
    open figures.
    """
    import os

    samples = (
        (data, "all", "black"),
        (data_pos, "pos", "b"),
        (data_neg, "neg", "r"),
    )

    fig = plt.figure()
    try:
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)

        if logx:
            plt.xscale('log')

        if logy:
            plt.yscale('log')

        drew_any = False
        for sample, label, color in samples:
            if not sample:
                continue
            plt.hist(flatten(sample, key), weights = flatten(sample, "w"),
                     bins = bins, label = label, histtype = "step", color = color,
                     fill = False, density = density)
            drew_any = True

        # legend() with no labelled artists only emits a warning.
        if drew_any:
            plt.legend()

        if outdir:
            os.makedirs(outdir, exist_ok = True)
        # Avoid "E.png.png" when the name already carries the extension.
        fname = name if name.endswith(".png") else name + ".png"
        plt.savefig(os.path.join(outdir, fname))
    finally:
        plt.close(fig)

if __name__ == "__main__":

    path = "/eos/experiment/faser/gen/MC22/Fluka"
    # Other input locations used during development:
    #   path = "/user/gwilliam"                  (unit30_Pm / unit30_Nm)
    #   path = "/bundle/data/FASER/Carl_Output"  (the two combined files)

    # Integer "type" codes used to split a combined file into its positive
    # ("pos") and negative ("neg") samples.
    POS_TYPE, NEG_TYPE = 10, 11

    ebins = np.linspace(0, 5000, 50)
    pzbins = [0, 10, 25, 50, 75, 100, 150, 200, 250, 300, 350, 400, 450, 500,
              600, 700, 800, 900, 1000, 7000]

    def load_separate(pos_fname, neg_fname):
        """Load pos/neg samples from two files; 'all' is their concatenation."""
        pos = extract_data(pos_fname, fiducial = True)
        neg = extract_data(neg_fname, fiducial = True)
        # Entries are only read after loading, so the combined list can share
        # the same dicts; a deep copy would just double memory use.
        return pos + neg, pos, neg

    def load_combined(fname):
        """Load one file and split it into pos/neg samples by "type" code."""
        data = extract_data(fname, fiducial = True)  # add n = ... to cap entries
        pos = [d for d in data if d["type"] == POS_TYPE]
        neg = [d for d in data if d["type"] == NEG_TYPE]
        return data, pos, neg

    # (output-name tag, loader). Samples are loaded, plotted and released one
    # at a time instead of holding all of them in memory together.
    samples = [
        ("2100056", lambda: load_separate(f"{path}/unit30_Pm", f"{path}/unit30_Nm")),
        ("210007", lambda: load_combined(f"{path}/LHC_-160urad_magfield_2022TCL6_muons_rock_2e8pr")),
        ("210012", lambda: load_combined(f"{path}/ALL_lhc_ir1_coll_2023_exp001_fort.30")),
    ]

    for tag, load in samples:
        data, data_pos, data_neg = load()
        plot(data, data_pos, data_neg, "E", "Energy [GeV]", "Events",
             name = f"E_{tag}", bins = ebins)
        plot(data, data_pos, data_neg, "Pz", "Pz [GeV]", "Events",
             name = f"Pz_{tag}", bins = pzbins, logx = True)
        del data, data_pos, data_neg  # free this sample before loading the next

