import { Direction } from "./Direction";
import type { PipelineMetricMap } from "./Pipeline";

/**
 * Identifies which value of a confidence interval a number represents.
 * The string values are the property keys used in every
 * `Record<CountType, number>` (e.g. in {@link AdjustedCounts}).
 */
export enum CountType {
    /** Lower boundary of the confidence interval. */
    LowerBound = "Lower Bound",
    /** Point estimate (the bias-corrected value). */
    Predicted = "Predicted",
    /** Upper boundary of the confidence interval. */
    UpperBound = "Upper Bound",
}

/**
 * Counts per direction, each holding a value for every {@link CountType}.
 * In the output of `scalingBiasCorrection`, the net entry is derived from
 * upstream minus downstream.
 */
export interface AdjustedCounts {
    /** Upstream counts keyed by count type. */
    [Direction.Upstream]: Record<CountType, number>;
    /** Downstream counts keyed by count type. */
    [Direction.Downstream]: Record<CountType, number>;
    /** Net (upstream minus downstream) counts keyed by count type. */
    [Direction.Net]: Record<CountType, number>;
}

/** Counts with confidence bounds for one pipeline at one timestamp. */
export interface BoundedCount {
    /** Timestamp the counts belong to (from `scalingBiasCorrection`, the raw count's `startTimestamp`). */
    timestamp: Date;
    /** Identifier of the pipeline the counts belong to. */
    pipelineId: string;
    /** Per-direction counts, each with lower bound, prediction, and upper bound. */
    value: AdjustedCounts;
}

/**
 * A convenience class used to calculate the confidence-interval modifiers
 * for one direction of a given pipeline.
 *
 * The resulting values are multipliers: multiplying a raw count by the
 * `Predicted` value performs bias correction, and multiplying it by the
 * `LowerBound` / `UpperBound` values computes the boundaries of the
 * confidence interval.
 */
class CountCalculator {
    /**
     * @param map - Source of each pipeline's bias and margin of error.
     */
    constructor(private readonly map: PipelineMetricMap) {}

    /**
     * Calculates the modifiers to apply to a raw count for the given
     * pipeline and direction.
     *
     * @param pipelineId - Pipeline whose metrics are looked up.
     * @param direction - Direction whose metrics are looked up.
     * @returns `bias - moe`, `bias`, and `bias + moe` keyed by
     *   {@link CountType}, where `bias` comes from `map.getBias` and `moe`
     *   from `map.getMarginOfError`.
     */
    calculateConfidenceInterval(
        pipelineId: string,
        direction: Direction,
    ): Record<CountType, number> {
        const bias = this.map.getBias(pipelineId, direction);
        const moe = this.map.getMarginOfError(pipelineId, direction);

        // NOTE(review): nothing clamps the lower modifier, so it is negative
        // whenever moe > bias — confirm that is acceptable for callers.
        return {
            [CountType.LowerBound]: bias - moe,
            [CountType.Predicted]: bias,
            [CountType.UpperBound]: bias + moe,
        };
    }
}

/** An uncorrected upstream/downstream detection count for one pipeline. */
interface RawCount {
    /** Start timestamp of the count; copied to `BoundedCount.timestamp`. */
    startTimestamp: Date;
    /**
     * Pipeline identifier. NOTE: spelled `pipelineID` here but
     * `pipelineId` on `BoundedCount`.
     */
    pipelineID: string;
    /** Raw upstream count; `null` is treated as 0 by `scalingBiasCorrection`. */
    upstreamCount: number | null;
    /** Raw downstream count; `null` is treated as 0 by `scalingBiasCorrection`. */
    downstreamCount: number | null;
}

/**
 * Performs bias correction on raw counts using the scaling technique.
 *
 * Each raw upstream/downstream count is multiplied by the modifiers that
 * {@link CountCalculator} derives for its pipeline and direction, and each
 * product is floored to an integer. Net counts are upstream minus
 * downstream. The net interval runs from (upstream lower - downstream upper)
 * to (upstream upper - downstream lower).
 *
 * @param detections - Raw counts to correct. A `null` directional count is
 *   treated as 0.
 * @param metrics - Gives access to each pipeline's aggregated proportional
 *   bias and calculated margin of error (_ci95), used to build the
 *   confidence intervals.
 * @returns One {@link BoundedCount} per detection, in input order.
 */
export const scalingBiasCorrection = (
    detections: RawCount[],
    metrics: PipelineMetricMap,
): BoundedCount[] => {
    const calculator = new CountCalculator(metrics);

    // Scales one raw directional count by each confidence-interval modifier
    // for that pipeline/direction, flooring every result to an integer.
    const scaleCount = (
        rawCount: number | null,
        pipelineId: string,
        direction: Direction,
    ): Record<CountType, number> => {
        const count = rawCount ?? 0;
        const modifiers = calculator.calculateConfidenceInterval(
            pipelineId,
            direction,
        );
        return {
            [CountType.LowerBound]: Math.floor(
                count * modifiers[CountType.LowerBound],
            ),
            [CountType.Predicted]: Math.floor(
                count * modifiers[CountType.Predicted],
            ),
            [CountType.UpperBound]: Math.floor(
                count * modifiers[CountType.UpperBound],
            ),
        };
    };

    return detections.map((detection): BoundedCount => {
        const upstream = scaleCount(
            detection.upstreamCount,
            detection.pipelineID,
            Direction.Upstream,
        );
        const downstream = scaleCount(
            detection.downstreamCount,
            detection.pipelineID,
            Direction.Downstream,
        );

        return {
            timestamp: detection.startTimestamp,
            pipelineId: detection.pipelineID,
            value: {
                [Direction.Upstream]: upstream,
                [Direction.Downstream]: downstream,
                // Interval subtraction: the net range covers every difference
                // between a value in the upstream interval and one in the
                // downstream interval.
                [Direction.Net]: {
                    [CountType.LowerBound]:
                        upstream[CountType.LowerBound] -
                        downstream[CountType.UpperBound],
                    [CountType.Predicted]:
                        upstream[CountType.Predicted] -
                        downstream[CountType.Predicted],
                    [CountType.UpperBound]:
                        upstream[CountType.UpperBound] -
                        downstream[CountType.LowerBound],
                },
            },
        };
    });
};

/**
 * Computes a running total of bounded counts.
 *
 * Element `i` of the result holds the sum of `boundedCounts[0..i]` for every
 * direction and count type, and carries the timestamp and pipeline id of
 * `boundedCounts[i]`. Every sum is floored. Counts are accumulated in input
 * order and are NOT grouped by `pipelineId`. Callers wanting per-pipeline
 * totals must pre-filter, and presumably sort by timestamp, before calling.
 *
 * @param boundedCounts - Counts to accumulate; not mutated.
 * @returns A new array, the same length as the input, of cumulative counts.
 */
export function cumulativeSum(boundedCounts: BoundedCount[]): BoundedCount[] {
    // Adds two per-CountType records, flooring every sum (applied uniformly
    // to all count types and directions).
    const addBounds = (
        total: Record<CountType, number>,
        increment: Record<CountType, number>,
    ): Record<CountType, number> => ({
        [CountType.LowerBound]: Math.floor(
            total[CountType.LowerBound] + increment[CountType.LowerBound],
        ),
        [CountType.Predicted]: Math.floor(
            total[CountType.Predicted] + increment[CountType.Predicted],
        ),
        [CountType.UpperBound]: Math.floor(
            total[CountType.UpperBound] + increment[CountType.UpperBound],
        ),
    });

    // Fresh all-zero record for one direction.
    const zeroBounds = (): Record<CountType, number> => ({
        [CountType.LowerBound]: 0,
        [CountType.Predicted]: 0,
        [CountType.UpperBound]: 0,
    });

    let cumulative: AdjustedCounts = {
        [Direction.Upstream]: zeroBounds(),
        [Direction.Downstream]: zeroBounds(),
        [Direction.Net]: zeroBounds(),
    };

    return boundedCounts.map((boundedCount): BoundedCount => {
        const increment = boundedCount.value;
        // Each step builds a fresh object. Every returned element therefore
        // owns its own totals, and later steps never mutate earlier results.
        cumulative = {
            [Direction.Upstream]: addBounds(
                cumulative[Direction.Upstream],
                increment[Direction.Upstream],
            ),
            [Direction.Downstream]: addBounds(
                cumulative[Direction.Downstream],
                increment[Direction.Downstream],
            ),
            [Direction.Net]: addBounds(
                cumulative[Direction.Net],
                increment[Direction.Net],
            ),
        };

        return {
            timestamp: boundedCount.timestamp,
            pipelineId: boundedCount.pipelineId,
            value: cumulative,
        };
    });
}
