import _ from "lodash";
import { findMissingDates } from "../utils/Date";
import { Direction } from "./Direction";
import { DEFAULT_BIAS, DEFAULT_MoE, type PipelineMetricMap } from "./Pipeline";

/**
 * The kinds of value tracked for each direction of a count.
 *
 * - `Predicted`: the raw count as produced by the model.
 * - `Adjusted`: the raw count scaled by the pipeline's bias.
 * - `LowerBound` / `UpperBound`: the raw count scaled by bias minus / plus
 *   the margin of error, i.e. the edges of the confidence interval.
 *
 * The enum is string-valued, so member values double as property names
 * (see `TypedCount`).
 */
export enum CountType {
    UpperBound = "UpperBound",
    Adjusted = "Adjusted",
    LowerBound = "LowerBound",
    Predicted = "Predicted",
}

/**
 * Turns a PascalCase `CountType` value into a human-readable label by
 * putting a space in front of every capital letter,
 * e.g. `UpperBound` -> `Upper Bound`.
 */
export function formatIntervalBound(value: CountType): string {
    const label = value.toString();
    // Zero-width lookahead: insert " " before each capital, then drop the
    // leading space that the first capital produces.
    const spaced = label.replace(/(?=[A-Z])/g, " ");
    return spaced.trim();
}

/**
 * One value for every `CountType` of a single direction.
 * Individual values may be `null` or `undefined`.
 */
export type TypedCount = Record<CountType, number | undefined | null>;

/**
 * A `TypedCount` for every `Direction` (for example the upstream,
 * downstream and net entries built by `scalingBiasCorrection`).
 */
export type DirectionalCount = Record<Direction, TypedCount>;

/**
 * The directional payload of a count. Either part may be absent.
 * - `instant`: the values for this time slot alone.
 * - `cumulative`: a running total, as produced by `cumulativeSum`.
 */
export type CountValue = {
    cumulative?: DirectionalCount;
    instant?: DirectionalCount;
};

/**
 * A single timestamped count together with its directional payload.
 *
 * `pipelineIds` lists the pipelines that contributed to the count. It is
 * `null` on the placeholder entries `alignCounts` creates for missing
 * slots; those placeholders are also flagged with `imputed: true`.
 */
export type Count = CountValue & {
    timestamp: Date;
    pipelineIds: string[] | null;
    imputed?: boolean;
};

/**
 * Convenience wrapper around a `PipelineMetricMap` that turns a pipeline's
 * per-direction bias and margin of error into multiplicative modifiers.
 * Multiplying a predicted count by these modifiers yields the
 * bias-corrected count and the bounds of its confidence interval.
 */
class CountCalculator {
    constructor(private readonly map: PipelineMetricMap) {}

    /**
     * Computes the modifiers for one pipeline/direction pair.
     *
     * `Adjusted` is the pipeline's bias, and the bounds are the bias minus
     * or plus the margin of error. When the map yields `null`/`undefined`
     * for either metric, `DEFAULT_BIAS` / `DEFAULT_MoE` is used instead.
     *
     * NOTE(review): nothing clamps the lower bound at 0, so a margin of
     * error larger than the bias produces a negative modifier. Confirm this
     * is intended.
     */
    calculateConfidenceInterval(
        pipelineId: string,
        direction: Direction,
    ): {
        [CountType.LowerBound]: number;
        [CountType.Adjusted]: number;
        [CountType.UpperBound]: number;
    } {
        const { map } = this;
        const adjusted = map.getBias(pipelineId, direction) ?? DEFAULT_BIAS;
        const margin =
            map.getMarginOfError(pipelineId, direction) ?? DEFAULT_MoE;

        const lower = adjusted - margin;
        const upper = adjusted + margin;
        return {
            [CountType.LowerBound]: lower,
            [CountType.Adjusted]: adjusted,
            [CountType.UpperBound]: upper,
        };
    }
}

/** A raw (uncorrected) model prediction for a single pipeline. */
interface PredictedCount {
    /** Identifier of the pipeline that produced the prediction. */
    pipelineID: string;
    /** Start of the prediction window; parsed with `new Date(...)` downstream. */
    startTimestamp: string;
    /** Predicted upstream count, or `null` when absent. */
    upstreamCount: number | null;
    /** Predicted downstream count, or `null` when absent. */
    downstreamCount: number | null;
}

/**
 * Scales one direction's raw predicted count by its confidence-interval
 * modifiers, flooring each scaled value to a whole count. `Predicted`
 * keeps the raw value unchanged.
 */
function scaleDirectionalCount(
    raw: number,
    modifiers: ReturnType<CountCalculator["calculateConfidenceInterval"]>,
): Record<CountType, number> {
    return {
        [CountType.Predicted]: raw,
        [CountType.LowerBound]: Math.floor(raw * modifiers[CountType.LowerBound]),
        [CountType.Adjusted]: Math.floor(raw * modifiers[CountType.Adjusted]),
        [CountType.UpperBound]: Math.floor(raw * modifiers[CountType.UpperBound]),
    };
}

/**
 * Performs bias correction using the scaling technique.
 *
 * Each detection's upstream and downstream counts (`null` treated as 0) are
 * multiplied by the pipeline's bias and by bias ∓ margin of error, then
 * floored. Net is upstream minus downstream. Its lower bound pairs the
 * upstream lower bound with the downstream upper bound (and vice versa), so
 * the net interval covers the widest combination.
 *
 * Detections that share the exact same timestamp are then merged into one
 * `Count`. Its `cumulative` field is the sum over the group, and its
 * `pipelineIds` lists every contributing pipeline.
 *
 * @param detections raw predictions to correct
 * @param metrics gives access to each pipeline's aggregated proportional
 *   bias and calculated margin of error (_ci95), used to build the
 *   confidence intervals
 * @returns one `Count` per distinct timestamp, in first-seen order
 */
export const scalingBiasCorrection = (
    detections: PredictedCount[],
    metrics: PipelineMetricMap,
): Count[] => {
    const calcr = new CountCalculator(metrics);

    const corrected = detections.map((detection): Count => {
        const upstream = scaleDirectionalCount(
            detection.upstreamCount ?? 0,
            calcr.calculateConfidenceInterval(
                detection.pipelineID,
                Direction.Upstream,
            ),
        );
        const downstream = scaleDirectionalCount(
            detection.downstreamCount ?? 0,
            calcr.calculateConfidenceInterval(
                detection.pipelineID,
                Direction.Downstream,
            ),
        );

        return {
            timestamp: new Date(detection.startTimestamp),
            pipelineIds: [detection.pipelineID],
            instant: {
                [Direction.Upstream]: upstream,
                [Direction.Downstream]: downstream,
                // Opposite bounds are paired so the net interval is the widest.
                [Direction.Net]: {
                    [CountType.Predicted]:
                        upstream[CountType.Predicted] -
                        downstream[CountType.Predicted],
                    [CountType.LowerBound]:
                        upstream[CountType.LowerBound] -
                        downstream[CountType.UpperBound],
                    [CountType.Adjusted]:
                        upstream[CountType.Adjusted] -
                        downstream[CountType.Adjusted],
                    [CountType.UpperBound]:
                        upstream[CountType.UpperBound] -
                        downstream[CountType.LowerBound],
                },
            },
        };
    });

    // Group by the exact instant. Using the Date itself as a lodash groupBy
    // key stringifies it with Date#toString, which only has one-second
    // precision and would merge distinct sub-second timestamps. A Map keyed
    // by epoch milliseconds iterates in insertion (first-seen) order.
    const byInstant = new Map<number, Count[]>();
    for (const count of corrected) {
        const key = count.timestamp.getTime();
        const group = byInstant.get(key);
        if (group === undefined) {
            byInstant.set(key, [count]);
        } else {
            group.push(count);
        }
    }

    const merged: Count[] = [];
    for (const dateCounts of byInstant.values()) {
        // NOTE(review): cumulativeSum passes each count's own `instant`
        // through, so the merged Count's `instant` is that of the last
        // detection in the group only, while `cumulative` sums the whole
        // group. Confirm that consumers read `cumulative` here.
        const combined = cumulativeSum(dateCounts).at(-1);
        if (combined === undefined) {
            continue;
        }
        combined.pipelineIds = _.flatMap(dateCounts, (count) =>
            _.compact(count.pipelineIds),
        );
        merged.push(combined);
    }
    return merged;
};

/** A per-direction total with every count type set to zero. */
function zeroTypedCount(): TypedCount {
    return {
        [CountType.Predicted]: 0,
        [CountType.LowerBound]: 0,
        [CountType.Adjusted]: 0,
        [CountType.UpperBound]: 0,
    };
}

/** Field-wise sum of two typed counts; `null`/`undefined` contribute 0. */
function addTypedCounts(total: TypedCount, increment: TypedCount): TypedCount {
    return {
        [CountType.Predicted]:
            (total[CountType.Predicted] ?? 0) +
            (increment[CountType.Predicted] ?? 0),
        [CountType.LowerBound]:
            (total[CountType.LowerBound] ?? 0) +
            (increment[CountType.LowerBound] ?? 0),
        [CountType.Adjusted]:
            (total[CountType.Adjusted] ?? 0) +
            (increment[CountType.Adjusted] ?? 0),
        [CountType.UpperBound]:
            (total[CountType.UpperBound] ?? 0) +
            (increment[CountType.UpperBound] ?? 0),
    };
}

/** Direction-wise sum of two directional counts. */
function addDirectionalCounts(
    total: DirectionalCount,
    increment: DirectionalCount,
): DirectionalCount {
    return {
        [Direction.Upstream]: addTypedCounts(
            total[Direction.Upstream],
            increment[Direction.Upstream],
        ),
        [Direction.Downstream]: addTypedCounts(
            total[Direction.Downstream],
            increment[Direction.Downstream],
        ),
        [Direction.Net]: addTypedCounts(
            total[Direction.Net],
            increment[Direction.Net],
        ),
    };
}

/**
 * Computes a running total of `instant` values over `counts`, in order.
 *
 * Each returned Count carries the total up to and including itself in
 * `cumulative`. Imputed counts, and counts without an `instant`, do not
 * change the total. They are returned with the current total and without
 * an `instant`. Within an `instant`, `null`/`undefined` values contribute 0.
 * Previously an `undefined` turned every later total into NaN, and a
 * missing `instant` threw.
 *
 * The same `cumulative` object may be shared by consecutive carried-forward
 * entries, so treat the results as read-only.
 *
 * @param counts counts to accumulate; a nullish value yields `[]`
 * @returns one Count per input, in the same order
 */
export function cumulativeSum(counts: Count[]): Count[] {
    if (!counts) {
        return [];
    }

    let running: DirectionalCount = {
        [Direction.Upstream]: zeroTypedCount(),
        [Direction.Downstream]: zeroTypedCount(),
        [Direction.Net]: zeroTypedCount(),
    };

    return counts.map((count): Count => {
        if (count.imputed || count.instant === undefined) {
            // Nothing to add: carry the current total forward.
            return {
                timestamp: count.timestamp,
                pipelineIds: count.pipelineIds,
                imputed: count.imputed,
                cumulative: running,
            };
        }

        running = addDirectionalCounts(running, count.instant);

        return {
            timestamp: count.timestamp,
            pipelineIds: count.pipelineIds,
            imputed: count.imputed,
            cumulative: running,
            instant: count.instant,
        };
    });
}

/**
 * Aligns several count series onto one shared, chronologically sorted
 * timeline.
 *
 * The timeline is built in two steps:
 * 1. Take every timestamp observed in any series, deduplicated by ISO
 *    string (first occurrence kept).
 * 2. Add whatever `findMissingDates` reports for `start`..`end`, using a
 *    one-hour step (NOTE(review): presumably the gap-detection interval;
 *    confirm against `findMissingDates`).
 *
 * Each series is then re-emitted against that timeline. A slot with no
 * count in that series becomes a placeholder with `pipelineIds: null` and
 * `imputed: true`. If a series repeats a timestamp, its last entry wins.
 *
 * @returns one array per input series, each with the timeline's timestamps
 */
export function alignCounts(
    counts: Count[][],
    start: Date,
    end: Date,
): Count[][] {
    const ONE_HOUR_MS = 60 * 60 * 1000;
    const chronologically = (a: Date, b: Date): number =>
        a.getTime() - b.getTime();

    // First-seen Date for each distinct instant across all series.
    const distinct = new Map<string, Date>();
    for (const series of counts) {
        for (const { timestamp } of series) {
            const iso = timestamp.toISOString();
            if (!distinct.has(iso)) {
                distinct.set(iso, timestamp);
            }
        }
    }

    // Gap detection receives the observed hours in order...
    const observedHours = Array.from(distinct.values()).sort(chronologically);
    // ...and the gap-filled timeline is re-sorted so missing hours land in place.
    const timeline = _.concat(
        observedHours,
        findMissingDates(observedHours, start, end, ONE_HOUR_MS),
    ).sort(chronologically);

    return counts.map((series) => {
        const byInstant = new Map<string, Count>();
        for (const item of series) {
            byInstant.set(item.timestamp.toISOString(), item);
        }
        return timeline.map(
            (timestamp): Count =>
                byInstant.get(timestamp.toISOString()) ?? {
                    pipelineIds: null,
                    timestamp: timestamp,
                    imputed: true,
                },
        );
    });
}

/**
 * Decides whether confidence intervals can be shown for `detections`.
 *
 * Returns false when `metrics` is empty, since CIs must be disabled when
 * no pipeline metrics back the counts. Otherwise returns true as soon as
 * at least one detection is eligible, so an empty `detections` list yields
 * false.
 */
export function isCIEligible(
    detections: PredictedCount[],
    metrics: PipelineMetricMap,
): boolean {
    const hasMetrics = !metrics.isEmpty();
    return (
        hasMetrics &&
        detections.some((detection) => isDetectionEligible(detection, metrics))
    );
}

/**
 * A detection is eligible only when its pipeline is eligible in both the
 * upstream and the downstream direction. Both lookups are always
 * evaluated (no short-circuit), upstream first.
 */
function isDetectionEligible(
    detection: PredictedCount,
    metrics: PipelineMetricMap,
): boolean {
    const { pipelineID } = detection;
    const perDirection = [Direction.Upstream, Direction.Downstream].map(
        (direction) => isEligible(metrics, pipelineID, direction),
    );
    return perDirection.every(Boolean);
}

/**
 * True when `metrics` has both a margin of error and a bias (anything
 * other than `undefined`) for the given pipeline and direction. The bias
 * is only looked up once a margin of error is known to exist.
 *
 * NOTE(review): a `null` value counts as present here, whereas
 * CountCalculator falls back to defaults on `null` as well (`??`). Confirm
 * the metric map never returns `null`.
 */
function isEligible(
    metrics: PipelineMetricMap,
    pipelineID: string,
    direction: Direction,
): boolean {
    if (metrics.getMarginOfError(pipelineID, direction) === undefined) {
        return false;
    }
    return metrics.getBias(pipelineID, direction) !== undefined;
}
