import silhouetteScore from '@robzzson/silhouette';
import { Matrix, SingularValueDecomposition } from 'ml-matrix';

import { defaultColorPalette, colorToTransparent } from './colormap';
import { abs, argMax, getGroupedMap, meanAxis } from './utils';
/*
Cluster quality score: the silhouette score of the points in `a` under `labels`, scaled by 100.
String labels are mapped to integer ids in order of first appearance before scoring.
*/
export const getCqs = (a: number[][], labels: string[]): number => {
  const idByLabel = new Map<string, number>();
  const numericLabels: number[] = labels.map((label) => {
    let id = idByLabel.get(label);
    if (id === undefined) {
      // First time this label is seen: give it the next free id.
      id = idByLabel.size;
      idByLabel.set(label, id);
    }
    return id;
  });
  return 100 * silhouetteScore(a, numericLabels);
};

/*
Computes the statistical certainty ellipse around a cluster of points.
Should be calculated for each group separately; only the first two coordinates define the ellipse.
Returns 5 parameters of an ellipse: centerX, centerY, radiusX, radiusY, angle (counter-clockwise in radians).
radiusX lies along the direction of largest variance and radiusY along the next one.
Both radii are sqrt(2) * standard deviation, following
https://scikit-learn.org/stable/auto_examples/mixture/plot_gmm.html
*/
export const getGaussianEllipse = (a: number[][]): number[] => {
  const aMat = new Matrix(a);
  const nSamples = aMat.rows;

  const mean = meanAxis(a, 0) as number[];
  const [meanX, meanY] = mean;

  // Centre the data (aMat is a fresh Matrix built from `a`).
  const diff = aMat.subRowVector(new Matrix([mean]));

  // Population (1/n) covariance matrix.
  const covar = diff.transpose().mmul(diff).div(nSamples);

  // The covariance is symmetric positive semi-definite, so its singular values are its
  // eigenvalues and the COLUMNS of V are the matching eigenvectors.
  const svd = new SingularValueDecomposition(covar);
  const singularValues = svd.diagonal;
  const eigvecR = svd.rightSingularVectors;

  // Rank eigen-pairs by decreasing eigenvalue instead of relying on the SVD's output order.
  const order = singularValues.map((_, i) => i).sort((i, j) => singularValues[j] - singularValues[i]);
  const [major, minor] = order;

  const stdevX = Math.sqrt(2) * Math.sqrt(singularValues[major]);
  const stdevY = Math.sqrt(2) * Math.sqrt(singularValues[minor]);

  // The ellipse's first axis points along the principal eigenvector, i.e. a column of V.
  // A row of V is not an eigenvector in general, and reading one flips the angle's sign
  // depending on the SVD's sign convention.
  const u = eigvecR.getColumn(major);
  const angle = Math.atan2(u[1], u[0]);
  return [meanX, meanY, stdevX, stdevY, angle];
};

/*
Converts ellipse parameters into an SVG path string.
cx, cy: centre; a, b: semi-axis lengths; theta: counter-clockwise angle (radians) of the `a` axis.
The outline is drawn as two cubic Bezier halves between the two ends of the `a` axis:
a `C` segment for one half and an `S` segment (which mirrors the previous control point) for the other.
Control points are pushed b / 0.75 (= 4b/3) along the `b` axis, so each half's midpoint sits at distance b.
*/
export const getEllipseSVGPath = (cx: number, cy: number, a: number, b: number, theta: number): string => {
  // Offset from the centre to one end of the `a` axis.
  const axisX = a * Math.cos(-theta);
  const axisY = -a * Math.sin(-theta);
  // Offset along the `b` axis, stretched to the Bezier control distance.
  const ctrlX = (-b * Math.cos(Math.PI / 2 - theta)) / 0.75;
  const ctrlY = (b * Math.sin(Math.PI / 2 - theta)) / 0.75;

  const point = (x: number, y: number): string => `${x},${y}`;

  // The two ends of the `a` axis.
  const startX = cx - axisX;
  const startY = cy - axisY;
  const endX = cx + axisX;
  const endY = cy + axisY;

  const start = point(startX, startY);
  const end = point(endX, endY);
  const ctrlStart = point(startX + ctrlX, startY + ctrlY);
  const ctrlEnd = point(endX + ctrlX, endY + ctrlY);
  const ctrlReturn = point(startX - ctrlX, startY - ctrlY);

  return `M ${start} C ${ctrlStart} ${ctrlEnd} ${end} S ${ctrlReturn} ${start}`;
};

/*
Multiplies an ellipse's two radii (entries 2 and 3) by `scale`.
Centre and angle are passed through; a new array is returned and the input is left untouched.
*/
export const scaleEllipse = (ellipse: number[], scale: number): number[] => {
  const [centerX, centerY, radiusX, radiusY, angle] = ellipse;
  const scaledRadii = [radiusX * scale, radiusY * scale];
  return [centerX, centerY, ...scaledRadii, angle];
};

/*
Scales the radii of every group's ellipse by `scaleFactor` (see scaleEllipse).
Returns a new record with the same group keys; the input record and its arrays are not modified.
*/
export const scaleEllipsesSigma = (
  groupedEllipses: Record<string, number[]>,
  scaleFactor: number
): Record<string, number[]> => {
  // Object keys are already unique, so each group maps straight to its scaled ellipse;
  // there is no need to regroup values and unwrap single-element lists.
  const scaled: Record<string, number[]> = {};
  for (const [group, ellipse] of Object.entries(groupedEllipses)) {
    scaled[group] = scaleEllipse(ellipse, scaleFactor);
  }
  return scaled;
};

/*
Builds a Plotly figure ({ data, layout }) for a 2D PCA projection.
- shapes: one semi-transparent path per ellipse; optimizedGroupedEllipses are emitted first, then groupedEllipses.
- data: one marker scatter trace per group in groupedProjections, using the first two coordinates.
- axes are labelled with the explained-variance percentage of the first two components, with equal x/y scaling.
Colours are looked up as defaultColorPalette[cmap[group]].
*/
export const getPcaFigure = (
  title: string,
  cmap: Record<string, string>,
  pcaExplainedVarianceRatio: number[],
  groupedProjections: Record<string, number[][]>,
  groupedEllipses: Record<string, number[]>,
  optimizedGroupedEllipses: Record<string, number[]>
): any => {
  // Palette colour assigned to a group.
  const groupColor = (group: string) => defaultColorPalette[cmap[group]];

  // Converts a record of ellipses into outlined, lightly filled path shapes.
  const toShapes = (ellipses: Record<string, number[]>): any[] =>
    Object.entries(ellipses).map(([group, [cx, cy, a, b, theta]]) => ({
      type: 'path',
      path: getEllipseSVGPath(cx, cy, a, b, theta),
      fillcolor: colorToTransparent(groupColor(group), 10),
      line: {
        width: 1,
        color: groupColor(group),
      },
    }));

  const shapes: any[] = [...toShapes(optimizedGroupedEllipses), ...toShapes(groupedEllipses)];

  const layout: any = {
    hovermode: 'closest',
    title,
    xaxis: {
      title: {
        text: `PCA 1 (${(100 * pcaExplainedVarianceRatio[0]).toFixed(1)}%)`,
        font: { size: 12 },
      },
    },
    yaxis: {
      title: {
        text: `PCA 2 (${(100 * pcaExplainedVarianceRatio[1]).toFixed(1)}%)`,
        font: { size: 12 },
      },
      scaleanchor: 'x',
      scaleratio: 1,
    },
    annotations: [],
    dragmode: 'zoom',
    shapes,
  };

  const data: any[] = Object.entries(groupedProjections).map(([group, projs]) => ({
    type: 'scatter',
    name: group,
    legendgroup: group,
    x: projs.map((point) => point[0]),
    y: projs.map((point) => point[1]),
    mode: 'markers',
    marker: {
      size: 10,
      line: {
        color: groupColor(group),
        width: 2,
      },
      color: colorToTransparent(groupColor(group), 10),
    },
  }));

  return { data, layout };
};
