import { common } from "@mui/material/colors";
import _, { concat } from "lodash";
import { findMissingDates } from "../utils/Date";
import { Direction } from "./Direction";
import { DEFAULT_BIAS, DEFAULT_MoE, type PipelineMetricMap } from "./Pipeline";

/**
 * Labels for the three figures kept for every count: the predicted
 * (bias-corrected) value and the lower and upper bounds around it.
 * The members are used as the keys of each per-direction record.
 */
export enum CountType {
    LowerBound = "Lower Bound",
    Predicted = "Predicted",
    UpperBound = "Upper Bound",
}

/**
 * Lower-bound, predicted and upper-bound counts for the upstream,
 * downstream and net directions.
 */
export interface AdjustedCounts {
    [Direction.Upstream]: Record<CountType, number>;
    [Direction.Downstream]: Record<CountType, number>;
    [Direction.Net]: Record<CountType, number>;
}

/** The adjusted counts recorded for one timestamp. */
export interface BoundedCount {
    /** The point in time the counts belong to. */
    timestamp: Date;
    /**
     * Ids of the pipelines that contributed to `value`; `null` on the
     * placeholder entries that `alignCounts` inserts.
     */
    pipelineIds: string[] | null;
    /** The counts, or `null` on a placeholder entry that has no data. */
    value: AdjustedCounts | null;
    /**
     * Set to `true` by `alignCounts` on placeholder entries and by
     * `cumulativeSum` on entries that had no `value` of their own.
     */
    imputed?: boolean;
}

/**
 * Convenience class that calculates the confidence-interval modifiers
 * for one direction of a given pipeline.
 * Multiplying a predicted count by one of the resulting values either
 * performs the bias correction or yields a boundary of the
 * confidence interval.
 */
class CountCalculator {
    private readonly map: PipelineMetricMap;

    constructor(map: PipelineMetricMap) {
        this.map = map;
    }

    /**
     * Calculates the modifiers to apply to a count of the given pipeline
     * and direction.
     *
     * The bias and margin of error are read from the metric map; when the
     * map yields `null` or `undefined`, `DEFAULT_BIAS` / `DEFAULT_MoE` are
     * used instead.
     *
     * @returns the bias as the predicted modifier, and the bias minus/plus
     * the margin of error as the lower/upper bound modifiers.
     */
    calculateConfidenceInterval(
        pipelineId: string,
        direction: Direction,
    ): Record<CountType, number> {
        const center = this.map.getBias(pipelineId, direction) ?? DEFAULT_BIAS;
        const halfWidth =
            this.map.getMarginOfError(pipelineId, direction) ?? DEFAULT_MoE;

        const lower = center - halfWidth;
        const upper = center + halfWidth;

        return {
            [CountType.LowerBound]: lower,
            [CountType.Predicted]: center,
            [CountType.UpperBound]: upper,
        };
    }
}

/** An uncorrected count reported by a single pipeline. */
interface RawCount {
    /** Timestamp string; `scalingBiasCorrection` parses it with `new Date(...)`. */
    startTimestamp: string;
    /** Id of the pipeline that produced the counts; used to look up its metrics. */
    pipelineID: string;
    /** Raw upstream count; `scalingBiasCorrection` treats `null` as 0. */
    upstreamCount: number | null;
    /** Raw downstream count; `scalingBiasCorrection` treats `null` as 0. */
    downstreamCount: number | null;
}

/**
 * Performs bias correction using the scaling technique.
 *
 * Each detection's upstream and downstream counts are multiplied by the
 * lower-bound, predicted and upper-bound modifiers of its pipeline and
 * floored; the net interval is derived from the two directional intervals.
 * Corrected counts that share a timestamp are then summed into a single
 * entry that lists every contributing pipeline id.
 *
 * @param detections raw counts to correct; a `null` count is treated as 0.
 * @param metrics gives the function access to each pipeline's aggregated
 * proportional bias and the calculated margin of error (_ci95) in order to
 * calculate confidence intervals.
 * @returns one corrected count per distinct timestamp.
 */
export const scalingBiasCorrection = (
    detections: RawCount[],
    metrics: PipelineMetricMap,
): BoundedCount[] => {
    const calcr = new CountCalculator(metrics);

    /** Scales one directional raw count by each modifier of its pipeline. */
    const scaleCount = (
        pipelineId: string,
        direction: Direction,
        rawCount: number | null,
    ): Record<CountType, number> => {
        const modifiers = calcr.calculateConfidenceInterval(
            pipelineId,
            direction,
        );
        const count = rawCount ?? 0;

        return {
            [CountType.LowerBound]: Math.floor(
                count * modifiers[CountType.LowerBound],
            ),
            [CountType.Predicted]: Math.floor(
                count * modifiers[CountType.Predicted],
            ),
            [CountType.UpperBound]: Math.floor(
                count * modifiers[CountType.UpperBound],
            ),
        };
    };

    const corrected = detections.map((detection): BoundedCount => {
        const upstream = scaleCount(
            detection.pipelineID,
            Direction.Upstream,
            detection.upstreamCount,
        );
        const downstream = scaleCount(
            detection.pipelineID,
            Direction.Downstream,
            detection.downstreamCount,
        );

        return {
            timestamp: new Date(detection.startTimestamp),
            pipelineIds: [detection.pipelineID],
            value: {
                [Direction.Upstream]: upstream,
                [Direction.Downstream]: downstream,
                // The net range is widest when the opposite extremes are
                // paired: lowest upstream with highest downstream, and
                // vice versa.
                [Direction.Net]: {
                    [CountType.LowerBound]:
                        upstream[CountType.LowerBound] -
                        downstream[CountType.UpperBound],
                    [CountType.Predicted]:
                        upstream[CountType.Predicted] -
                        downstream[CountType.Predicted],
                    [CountType.UpperBound]:
                        upstream[CountType.UpperBound] -
                        downstream[CountType.LowerBound],
                },
            },
        };
    });

    return _(corrected)
        // Group on the epoch-millisecond value. Passing the Date itself
        // would key the groups on Date.prototype.toString(), a
        // timezone-formatted string with only one-second resolution.
        .groupBy((count) => count.timestamp.getTime())
        .map((dateCounts) => {
            // The last running total is the sum over the whole group.
            const cumulative = cumulativeSum(dateCounts).at(-1);
            if (cumulative !== undefined) {
                cumulative.pipelineIds = _.flatMap(dateCounts, (count) =>
                    _.compact(count.pipelineIds),
                );
            }
            return cumulative;
        })
        .compact()
        .value() as BoundedCount[];
};

/**
 * Turns a series of counts into running totals.
 *
 * Entries are processed in order. An entry with a `value` is added to the
 * running total (each sum is floored) and replaced by that total. An entry
 * without a `value` receives the total accumulated so far and is flagged
 * `imputed: true`.
 *
 * @param boundedCounts the counts to accumulate, in the order to sum them.
 * @returns a new array of the same length holding the running totals.
 */
export function cumulativeSum(boundedCounts: BoundedCount[]): BoundedCount[] {
    /** A fresh all-zero record for one direction. */
    const zeroCounts = (): Record<CountType, number> => ({
        [CountType.LowerBound]: 0,
        [CountType.Predicted]: 0,
        [CountType.UpperBound]: 0,
    });

    /** Adds one direction's counts onto its running total, flooring each sum. */
    const addFloored = (
        total: Record<CountType, number>,
        increment: Record<CountType, number>,
    ): Record<CountType, number> => ({
        [CountType.LowerBound]: Math.floor(
            total[CountType.LowerBound] + increment[CountType.LowerBound],
        ),
        [CountType.Predicted]: Math.floor(
            total[CountType.Predicted] + increment[CountType.Predicted],
        ),
        [CountType.UpperBound]: Math.floor(
            total[CountType.UpperBound] + increment[CountType.UpperBound],
        ),
    });

    let runningTotal: AdjustedCounts = {
        [Direction.Upstream]: zeroCounts(),
        [Direction.Downstream]: zeroCounts(),
        [Direction.Net]: zeroCounts(),
    };

    return boundedCounts.map(
        ({ timestamp, pipelineIds, value }): BoundedCount => {
            // Nothing to add: carry the current total forward.
            if (!value) {
                return {
                    timestamp,
                    pipelineIds,
                    value: runningTotal,
                    imputed: true,
                };
            }

            // Build a new total object rather than mutating the old one,
            // which earlier results still reference.
            runningTotal = {
                [Direction.Upstream]: addFloored(
                    runningTotal[Direction.Upstream],
                    value[Direction.Upstream],
                ),
                [Direction.Downstream]: addFloored(
                    runningTotal[Direction.Downstream],
                    value[Direction.Downstream],
                ),
                [Direction.Net]: addFloored(
                    runningTotal[Direction.Net],
                    value[Direction.Net],
                ),
            };

            return { timestamp, pipelineIds, value: runningTotal };
        },
    );
}

/**
 * Aligns several count series onto one shared, sorted list of timestamps.
 *
 * The shared list holds every distinct timestamp found in any series plus
 * the dates `findMissingDates` reports for the `start`..`end` range at a
 * one-hour step. Each series is then re-emitted along that list; where a
 * series has no entry for a timestamp, a placeholder with `value: null`,
 * `pipelineIds: null` and `imputed: true` is inserted.
 *
 * @param boundedCounts the series to align.
 * @param start start of the range passed to `findMissingDates`.
 * @param end end of the range passed to `findMissingDates`.
 * @returns one aligned series per input series, in the same order.
 */
export function alignCounts(
    boundedCounts: BoundedCount[][],
    start: Date,
    end: Date,
): BoundedCount[][] {
    // ms (1000) * seconds (60) * minutes (60) = hour
    const hourInMs = 1000 * 3600;
    const isoKey = ({ timestamp }: BoundedCount): string =>
        timestamp.toISOString();

    // One entry per distinct timestamp across all series.
    const distinctCounts = _.unionBy(_.flatten(boundedCounts), isoKey);

    // Preliminary sort for gap detection.
    const observedHours = _.sortBy(
        distinctCounts.map(({ timestamp }) => timestamp),
        (timestamp) => timestamp,
    );
    const missingHours = findMissingDates(
        observedHours,
        start,
        end,
        hourInMs,
    );

    // Post sort to handle hours which were missing.
    const seasonHours = _.sortBy(
        _.concat(observedHours, missingHours),
        (timestamp) => timestamp,
    );

    return boundedCounts.map((series) => {
        const countsByHour = new Map<string, BoundedCount>();
        for (const item of series) {
            countsByHour.set(isoKey(item), item);
        }

        return seasonHours.map((timestamp): BoundedCount => {
            const found = countsByHour.get(timestamp.toISOString());
            if (found) {
                return found;
            }
            return {
                pipelineIds: null,
                timestamp: timestamp,
                value: null,
                imputed: true,
            };
        });
    });
}

/**
 * Decides whether confidence intervals can be shown for a set of detections.
 *
 * @param detections the raw counts under consideration.
 * @param metrics the metrics of the pipelines that generated the counts.
 * @returns `false` when `metrics` is empty; otherwise `true` as soon as at
 * least one detection is eligible in both directions.
 */
export function isCIEligible(
    detections: RawCount[],
    metrics: PipelineMetricMap,
): boolean {
    // CIs must be disabled if there are no pipeline metrics that generated the counts
    return (
        !metrics.isEmpty() &&
        detections.some((detection) => isDetectionEligible(detection, metrics))
    );
}

/**
 * Checks whether confidence intervals can be computed for one detection.
 * Both directions are always evaluated.
 *
 * @returns `true` only when the detection's pipeline is eligible in both
 * the upstream and the downstream direction.
 */
function isDetectionEligible(detection: RawCount, metrics: PipelineMetricMap): boolean {
    const { pipelineID } = detection;
    const upstreamOk = isEligible(metrics, pipelineID, Direction.Upstream);
    const downstreamOk = isEligible(metrics, pipelineID, Direction.Downstream);

    // If either direction is ineligible for computing CIs, the detection is too.
    if (!upstreamOk) {
        return false;
    }
    return downstreamOk;
}

/**
 * Checks whether the metrics hold both values needed for a confidence
 * interval for the given pipeline and direction.
 *
 * @returns `true` when neither the margin of error nor the bias is
 * `undefined`. The bias is only looked up when the margin of error exists.
 */
function isEligible(metrics: PipelineMetricMap, pipelineID: string, direction: Direction): boolean {
    if (metrics.getMarginOfError(pipelineID, direction) === undefined) {
        return false;
    }
    return metrics.getBias(pipelineID, direction) !== undefined;
}