import { Ingredient, Outcome,
         PerIngredient, PerCategory, SegmentInfo, CategoryCounts } from '../sharedTypes';
import { SimulatorState } from '../store';


/**
 * Builds a lookup from each category's value to its display name
 * for the categories of the given ingredient.
 */
export function valuesToNames(ingredient: Ingredient) {
  const lookup: { [catValue: string]: string} = {};
  for (const { value, display_name } of ingredient.categories) {
    lookup[value] = display_name;
  }
  return lookup;
}

/** A number for each category, summing to 1. */
export type Distribution = PerCategory<number>;
/**
 * How the population of each category will be distributed across other categories.
 * `matrix[src][dst]` is the share of source category `src` that ends up in category `dst`.
 */
export type MovementMatrix = { [sourceCat: string]: Distribution };

/**
 * Returns the movement matrix for the best outcome based on the whatifs.
 *
 * Each segment's row gets a 0 for every key of its `whatIfs` and a single 1 at the
 * category chosen by `firstBetter`. A later category replaces the current choice only
 * when `firstBetter(candidate, current)` returns true.
 *
 * @param whatifValues per-segment what-if values
 * @param firstBetter returns true when its first argument is preferable to the second
 * @returns one row per entry of `whatifValues`, keyed by segment
 */
function bestMovementMatrix(
  whatifValues: SegmentInfo[], firstBetter: (a: number, b: number) => boolean): MovementMatrix {
  const m: MovementMatrix = {};
  for (const s of whatifValues) {
    const row: { [cat: string]: number } = {};
    let best: string | undefined;
    for (const k in s.whatIfs) {
      row[k] = 0;
      if (best === undefined || firstBetter(s.whatIfs[k], s.whatIfs[best])) {
        best = k;
      }
    }
    // A segment without any what-ifs has no best category. Leave its row empty
    // instead of writing a bogus "undefined" key into the matrix.
    if (best !== undefined) {
      row[best] = 1;
    }
    m[s.segment] = row;
  }
  return m;
}

/**
 * Applies a movement matrix to a distribution.
 *
 * For every destination category (taken from the keys of `matrix`) the result is the sum,
 * over all source categories, of the source's original share times the fraction of that
 * source that moves to the destination.
 */
function flattenDistribution(
  originalDistribution: Distribution, matrix: MovementMatrix): Distribution {
  const categories = Object.keys(matrix);
  const result: Distribution = {};
  categories.forEach(dst => {
    result[dst] = categories.reduce(
      (mass, src) => mass + originalDistribution[src] * matrix[src][dst], 0);
  });
  return result;
}

/**
 * Cosine similarity between two distributions. For non-negative inputs it's in [0, 1].
 *
 * Only the categories of `a` are visited; `b` is read at those same keys.
 * If either vector has zero length over those keys the similarity is defined as 0,
 * so callers never receive NaN from a division by zero.
 */
function distributionSimilarity(a: Distribution, b: Distribution): number {
  let dot = 0;
  let aSquared = 0;
  let bSquared = 0;
  for (const cat in a) {
    dot += a[cat] * b[cat];
    aSquared += a[cat] * a[cat];
    bSquared += b[cat] * b[cat];
  }
  if (aSquared === 0 || bSquared === 0) {
    return 0;
  }
  // Rounding can push the quotient marginally above 1 for (nearly) parallel vectors.
  return Math.min(1, dot / Math.sqrt(aSquared) / Math.sqrt(bSquared));
}

/**
 * Builds the movement matrix that leaves everybody in place: for every pair of
 * categories taken from the keys of `keys`, the entry is 1 on the diagonal and 0 elsewhere.
 * The values stored in `keys` are not read.
 */
function identityMatrix(keys: PerCategory<number>): MovementMatrix {
  const names: string[] = [];
  for (const name in keys) {
    names.push(name);
  }
  const identity: MovementMatrix = {};
  names.forEach(src => {
    identity[src] = {};
    names.forEach(dst => {
      identity[src][dst] = src === dst ? 1 : 0;
    });
  });
  return identity;
}

/**
 * Finds a MovementMatrix that is close to the one provided,
 * and yields the simulatedDistribution when flattened.
 * Assumes that "matrix" has just one 1 in each row.
 *
 * The input `matrix` is only read; all adjustments are written to a per-row copy.
 * Columns whose flattened mass exceeds the target are reduced first (Step 1), the freed
 * share of each row is tracked as "overflow", and that overflow is then used to top up
 * the columns that fall short (Steps 2 and 3).
 *
 * @param simulatedDistribution target distribution the result should flatten to
 * @param originalDistribution distribution the matrix is applied to
 * @param matrix starting movement matrix
 * @returns the adjusted copy of `matrix`
 */
export function closestMatchingMatrix(
  simulatedDistribution: Distribution, originalDistribution: Distribution,
  matrix: MovementMatrix): MovementMatrix {
  const sd = simulatedDistribution;
  const od = originalDistribution;
  const cols = Object.keys(od);
  // Copy each row so the caller's matrix is not modified.
  const closest: MovementMatrix = Object.fromEntries(cols.map(src => [src, {...matrix[src]}]));
  // Where the population would end up if "matrix" were applied unchanged.
  const flat = flattenDistribution(od, matrix);
  // overflow[src]: share of row src that was taken out of over-full columns and
  // still has to be placed somewhere else.
  const overflow = Object.fromEntries(cols.map(k => [k, 0]));
  // Step 1: Handle destinations where the matrix is greater than sd.
  // We compute the final values for these and save the "overflow".
  for (const dst in matrix) {
    if (flat[dst] > sd[dst]) { // Must decrease this column to match sd[dst].
      // Leave as much in the diagonal as possible.
      const diagonal = od[dst] * matrix[dst][dst];
      if (sd[dst] < diagonal) {
        // Even the people who stay put are too many: shrink the diagonal, drop the rest.
        for (const src in matrix) {
          if (src === dst) {
            closest[src][dst] = sd[dst] / od[src]; // Reduce diagonal.
          } else {
            closest[src][dst] = 0; // Others go to zero.
          }
          overflow[src] += matrix[src][dst] - closest[src][dst];
        }
      } else {
        // The diagonal fits; scale the off-diagonal entries by a common factor so that
        // their combined mass becomes sd[dst] - diagonal.
        const reduction = (sd[dst] - diagonal) / (flat[dst] - diagonal);
        for (const src in matrix) {
          if (src === dst) {
            closest[src][dst] = matrix[src][dst]; // Keep the diagonal element.
          } else {
            closest[src][dst] = matrix[src][dst] * reduction; // Reduce the others.
          }
          overflow[src] += matrix[src][dst] - closest[src][dst];
        }
      }
    }
  }
  // underflow[k]: mass (in units of the distribution, not of a row share) that column k
  // still lacks after Step 2.
  const underflow = Object.fromEntries(cols.map(k => [k, 0]));
  // Step 2: For the other columns we first look at the diagonal and put as much of the
  // overflow here as we can fit.
  for (const k in matrix) {
    if (flat[k] < sd[k]) {
      const diagonal = matrix[k][k];
      if (overflow[k] * od[k] < sd[k] - flat[k]) {
        closest[k][k] = diagonal + overflow[k]; // Put all the overflow here but it's still not enough.
      } else {
        // NOTE(review): this stores the required increase itself, not diagonal + increase.
        // That is the same value only when diagonal is 0 - presumably guaranteed here by the
        // "one 1 in each row" assumption; confirm before relaxing that assumption.
        closest[k][k] = (sd[k] - flat[k]) / od[k]; // We can cover it and still have overflow left.
      }
      const increase = closest[k][k] - diagonal;
      overflow[k] -= increase;
      underflow[k] = sd[k] - flat[k] - od[k] * increase;
    }
  }
  const totalUnder = Object.values(underflow).reduce((a, b) => a + b, 0);
  // Step 3: In the columns that are still not satisfied, we distribute the remaining overflow
  // of every other row in proportion to each column's share of the total underflow.
  for (const dst in matrix) {
    if (underflow[dst] > 0) {
      for (const src in matrix) {
        if (src !== dst) {
          closest[src][dst] = matrix[src][dst] + overflow[src] * underflow[dst] / totalUnder;
        }
      }
    }
  }
  // Those steps don't quite affect rows that have no people. The matrix doesn't really
  // matter for these, but for consistency's sake we set it to the identity here.
  for (const src in matrix) {
    if (od[src] === 0) {
      for (const dst in matrix) {
        closest[src][dst] = src === dst ? 1 : 0;
      }
    }
  }
  return closest;
}

/**
 * Turns similarities in [0, 1] into weights that sum to 1, in a way that keeps a 1 as 1
 * (pushing the others to 0), so a perfect match gets maximal weight.
 *
 * The weights are `s^x` for a common exponent `x` chosen so that the powers add up to
 * roughly 1. There is no closed form used here, so `x = 2^y` is found by bisection on
 * `y` in [-30, 30]; the result is rescaled at the end to absorb the remaining inaccuracy.
 *
 * When no similarity is positive, every entry gets the same weight `1 / sims.length`.
 */
export function normalizeSimilarities(sims: number[]): number[] {
  // Edge case: nothing is positive, so fall back to uniform weights.
  if (!sims.some(s => s > 0)) {
    const uniform = 1 / sims.length;
    return sims.map(() => uniform);
  }
  const powerSum = (exponent: number): number =>
    sims.reduce((acc, s) => acc + Math.pow(s, exponent), 0);
  let lo = -30;
  let hi = 30;
  while (hi - lo > 0.01) {
    const mid = (hi + lo) / 2;
    // A larger exponent shrinks every term below 1, so move up while the sum is too big.
    if (powerSum(Math.pow(2, mid)) > 1) {
      lo = mid;
    } else {
      hi = mid;
    }
  }
  const exponent = Math.pow(2, (hi + lo) / 2);
  const raw = sims.map(s => Math.pow(s, exponent));
  // Correct for edge cases and inaccuracies.
  const total = raw.reduce((acc, w) => acc + w, 0);
  return raw.map(w => w / total);
}

/**
 * Weighted sum of matrixes: entry [src][dst] of the result is the sum over i of
 * `weights[i] * matrixes[i][src][dst]`. The categories are taken from the first matrix.
 */
function linearMatrixCombination(weights: number[], matrixes: MovementMatrix[]): MovementMatrix {
  const template = matrixes[0];
  const combined: MovementMatrix = {};
  for (const src in template) {
    combined[src] = {};
    for (const dst in template) {
      combined[src][dst] = weights.reduce(
        (acc, weight, i) => acc + weight * matrixes[i][src][dst], 0);
    }
  }
  return combined;
}

/**
 * Weighted sum of vectors: entry x of the result is the sum over i of
 * `weights[i] * vectors[i][x]`. The keys are taken from the first vector.
 */
function linearVectorCombination(
  weights: number[], vectors: PerCategory<number>[]): PerCategory<number> {
  const template = vectors[0];
  const combined: PerCategory<number> = {};
  for (const key in template) {
    combined[key] = weights.reduce((acc, weight, i) => acc + weight * vectors[i][key], 0);
  }
  return combined;
}

/**
 * Finds a movement matrix that yields the simulatedDistribution when flattened.
 *
 * There are infinitely many such matrixes. The one returned is a weighted blend of three
 * candidates, each obtained by pulling an "anchor" matrix to the simulated distribution
 * with closestMatchingMatrix. The anchors are the matrix of the highest outcome, the
 * matrix of the lowest outcome, and the identity (nobody moves). Each candidate is
 * weighted by how similar the simulated distribution is to the anchor's own flattened
 * distribution.
 *
 * Side effect: `whatifValues` is padded in place (via fillMissing) with entries for
 * categories of the original distribution that have no segment yet.
 */
export function movementForDistribution(
  whatifValues: SegmentInfo[],
  originalDistribution: Distribution, simulatedDistribution: Distribution): MovementMatrix {
  fillMissing(whatifValues, Object.keys(originalDistribution));
  // The three matrixes we want to make reachable.
  const highest = bestMovementMatrix(whatifValues, (a, b) => a > b);
  const lowest = bestMovementMatrix(whatifValues, (a, b) => a < b);
  const unchanged = identityMatrix(originalDistribution);
  // How close is the target to the flat distribution of each anchor?
  // (The identity anchor flattens to the original distribution itself.)
  const weights = normalizeSimilarities([
    distributionSimilarity(
      simulatedDistribution, flattenDistribution(originalDistribution, highest)),
    distributionSimilarity(
      simulatedDistribution, flattenDistribution(originalDistribution, lowest)),
    distributionSimilarity(simulatedDistribution, originalDistribution),
  ]);
  // Pull every anchor onto the target and blend the results.
  const candidates = [highest, lowest, unchanged].map(
    anchor => closestMatchingMatrix(simulatedDistribution, originalDistribution, anchor));
  return linearMatrixCombination(weights, candidates);
}

/**
 * Pads `whatifValues` in place so that a "full" movement matrix can be generated.
 * Every category that has no segment entry yet gets one appended whose what-if value
 * is zero for every category.
 */
function fillMissing(whatifValues: SegmentInfo[], categories: string[]) {
  // Snapshot of the segments present before anything is appended.
  const present = new Set(whatifValues.map(info => info.segment));
  categories
    .filter(category => !present.has(category))
    .forEach(segment => {
      const whatIfs: PerCategory<number> = {};
      categories.forEach(category => {
        whatIfs[category] = 0.0;
      });
      whatifValues.push({ segment, whatIfs });
    });
}

/**
 * Returns the outcome for one ingredient when it's set to the given distribution:
 * derives a movement matrix for the target distribution and evaluates the outcome
 * of that movement.
 */
export function simulateOutcomeForIngredient(
  whatifValues: SegmentInfo[],
  originalDistribution: Distribution, simulatedDistribution: Distribution): number {
  return simulateOutcomeForMovement(
    whatifValues,
    originalDistribution,
    movementForDistribution(whatifValues, originalDistribution, simulatedDistribution));
}

/**
 * Returns the outcome for one ingredient when customers are moved according to the matrix.
 *
 * For each source category the expected value is the movement-weighted sum of that
 * segment's what-if values; categories without a segment entry contribute 0. The per-category
 * values are then averaged using the original distribution as weights.
 */
export function simulateOutcomeForMovement(
  whatifValues: SegmentInfo[],
  originalDistribution: Distribution, movement: MovementMatrix): number {
  const categories = Object.keys(originalDistribution);
  // Index the what-if values by segment for direct lookup.
  const whatIfsBySegment: { [catValue: string]: { [catValue: string]: number }} = {};
  whatifValues.forEach(({ segment, whatIfs }) => {
    whatIfsBySegment[segment] = whatIfs;
  });
  return categories.reduce((total, src) => {
    const averages = whatIfsBySegment[src];
    const expected = averages === undefined
      ? 0
      : categories.reduce((acc, dst) => acc + movement[src][dst] * averages[dst], 0);
    // We need to use the original distribution for the weighted avg.
    return total + originalDistribution[src] * expected;
  }, 0);
}

/**
 * Restricts a value to the outcome's allowed range. The lower bound (`minValue`) is
 * applied first, then the upper bound (`maxValue`); a bound that is undefined is skipped.
 */
export function clipOutcome(value: number, outcome: Outcome): number {
  const { minValue, maxValue } = outcome;
  const floored = minValue !== undefined && value < minValue ? minValue : value;
  return maxValue !== undefined && floored > maxValue ? maxValue : floored;
}

/**
 * Simulates the overall outcome for the distributions currently set in the simulator.
 *
 * Uses the "multiplicative/relative" model: starting from the current outcome value,
 * each ingredient multiplies the running value by the ratio of its own simulated
 * outcome to the current outcome value. The result is clipped to the outcome's range.
 *
 * @param simulatorState holds the original and simulated distribution of each ingredient
 * @param whatIfsPerIng what-if values per ingredient
 * @param currentOutcomeValue outcome value used as the baseline of the ratios
 * @param outcome supplies the clipping bounds
 * @param ingsInGroup if defined, we simulate only for this group of ingredients;
 *   otherwise every ingredient that has a simulated distribution is used
 */
export function simulateOutcome(
  simulatorState: SimulatorState,
  whatIfsPerIng: PerIngredient<SegmentInfo[]>,
  currentOutcomeValue: number,
  outcome: Outcome,
  // If ingsInGroup is defined, we simulate only for a group
  ingsInGroup?: Array<string>): number {
  const ingIds = ingsInGroup !== undefined
    ? ingsInGroup
    : Object.keys(simulatorState.simulatedDistributions);
  const simulatedValue = ingIds.reduce((running, ingId) => {
    const valueForIng = simulateOutcomeForIngredient(
      whatIfsPerIng[ingId],
      simulatorState.originalDistributions[ingId],
      simulatorState.simulatedDistributions[ingId]
    );
    return running * (valueForIng / currentOutcomeValue);
  }, currentOutcomeValue);
  return clipOutcome(simulatedValue, outcome);
}

/**
 * Converts per-category row counts into a distribution (shares of the total count).
 *
 * Every category declared on the ingredient is present in the result, with 0 when it has
 * no count. Categories that appear only in `counts` are kept and normalized like the rest,
 * and repeated rows for the same category are added together, so the result sums to 1
 * whenever the total count is non-zero. With a total of 0 all shares stay 0 (no NaN).
 *
 * @param ingredient supplies the known categories and the column name used to read the
 *   category value from each count row
 * @param counts rows carrying a category value (under `ingredient.column`) and a `count`
 */
export function computeDistributionFromCounts(
    ingredient: Ingredient, counts: CategoryCounts): PerCategory<number> {
  const pcts: { [catValue: string]: number} = {};
  for (const cat of ingredient.categories) {
    pcts[cat.value] = 0;
  }
  let sum = 0;
  for (const cnt of counts) {
    const catValue = cnt[ingredient.column];
    const count = cnt.count;
    // Accumulate instead of overwrite: the total below counts every row as well.
    pcts[catValue] = (pcts[catValue] || 0) + count;
    sum += count;
  }
  if (sum === 0) {
    // Nothing to normalize by; dividing would turn every share into NaN.
    return pcts;
  }
  // Normalize every key, including categories that only showed up in the counts.
  for (const catValue of Object.keys(pcts)) {
    pcts[catValue] /= sum;
  }
  return pcts;
}

/**
 * Computes the starting distribution of every ingredient from its category counts.
 * The result is keyed by the ingredient's column, which is also the key used to look
 * up the ingredient's counts.
 */
export function initialDistributions(
    sqIngredients: Array<Ingredient>,
    categoryCounts: PerIngredient<CategoryCounts>): PerIngredient<PerCategory<number>> {
  const byColumn: PerIngredient<PerCategory<number>> = {};
  sqIngredients.forEach(ingredient => {
    const column = ingredient.column;
    byColumn[column] = computeDistributionFromCounts(ingredient, categoryCounts[column]);
  });
  return byColumn;
}

/**
 * Finds a distribution that returns the simulatedValue.
 *
 * Candidate distributions lie on a path parameterized by `mix`: the flattened
 * lowest-outcome distribution at -1, the original distribution at 0 and the flattened
 * highest-outcome distribution at 1, with linear interpolation in between. The path is
 * bisected until the interval is narrower than `epsilon`; the distribution evaluated in
 * the last iteration is returned.
 * NOTE(review): bisection presumes the simulated outcome does not decrease along the
 * path - confirm for the what-if data in use.
 *
 * Side effect: `whatifValues` is padded in place (via fillMissing) with entries for
 * categories of the original distribution that have no segment yet.
 *
 * @param whatifValues per-segment what-if values
 * @param originalDistribution distribution the movement is applied to
 * @param simulatedValue outcome value to reach
 */
export function backwardSimulation(
  whatifValues: SegmentInfo[],
  originalDistribution: Distribution, simulatedValue: number): Distribution {
  const od = originalDistribution;
  // Make sure every category has a segment entry before the min/max matrixes are built,
  // as movementForDistribution does. Without this, a category missing from whatifValues
  // has no matrix row, minFlat/maxFlat lack that key, and the mixed distribution gets NaN.
  fillMissing(whatifValues, Object.keys(od));
  // The min/max distributions.
  const maxMatrix = bestMovementMatrix(whatifValues, (a, b) => a > b);
  const maxFlat = flattenDistribution(od, maxMatrix);
  const minMatrix = bestMovementMatrix(whatifValues, (a, b) => a < b);
  const minFlat = flattenDistribution(od, minMatrix);
  // Binary search the mixing of minFlat (at -1), od (at 0), and maxFlat (at 1)
  // to find the linear combination that is closest to simulatedValue.
  const epsilon = 0.01;
  // Start slightly outside [-1, 1] so the end points themselves can be reached.
  let low = -1 - epsilon;
  let high = 1 + epsilon;
  let simulatedDistribution: Distribution = {};
  while (high - low > epsilon) {
    const mix = (high + low) / 2;
    // The search interval overshoots [-1, 1]; the mixing weights must not.
    const clamped = Math.min(1, Math.max(-1, mix));
    simulatedDistribution = linearVectorCombination(
      [1 - Math.abs(clamped), Math.abs(clamped)],
      mix < 0 ? [od, minFlat] : [od, maxFlat]);
    const v = simulateOutcomeForIngredient(whatifValues, od, simulatedDistribution);
    if (v < simulatedValue) {
      low = mix;
    } else {
      high = mix;
    }
  }
  return simulatedDistribution;
}
