// @ts-ignore
import * as ml5 from "ml5";

import type { GestureData } from "../../../modules/launcher/types";

/**
 * Explicit model, metadata and weights file locations, accepted by
 * `NeuralNetwork.load` as an alternative to a single model path.
 */
interface InputFiles {
  model: string;
  metadata: string;
  weights: string;
}

/**
 * Hand-written typing for the object returned by `ml5.neuralNetwork`,
 * limited to the members this module calls.
 *
 * Inputs and targets are typed as `object` (an array of values or an object
 * keyed by feature name) rather than the boxed `Object` type, which would
 * also accept primitives such as strings and numbers.
 */
type NeuralNetwork = {
  /** Adds one training example: `xs` are the inputs, `ys` the target(s). */
  addData: (xs: object, ys: object) => void;
  /** Classifies `inputs` and reports the outcome through `callback`. */
  classify: (
    inputs: object,
    // `result` stays `any`: its shape is defined by ml5, and callers obtain
    // this callback type via `Parameters<...>`, so narrowing it would break them.
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    callback: (error: string | undefined, result: any) => void
  ) => void;
  /** Loads a saved model from a single path or an explicit set of files. */
  load: (filesOrPath: string | InputFiles, callback?: () => void) => void;
  normalizeData: () => void;
  save: (outputName: string, callback?: () => void) => void;
  train: (
    options: { batchSize?: number; epochs?: number },
    callback?: () => void
  ) => void;
};

/**
 * One flattened sample as handed to ml5: the gesture `label` followed by
 * `<index>_x`, `<index>_y`, `<index>_z` entries for each landmark of the
 * first detected hand, in landmark order.
 */
type FeatureRow = { label: string; [feature: string]: string | number };

/**
 * Hand-gesture classifier backed by an ml5 classification network
 * (dense 24 relu -> dense 16 relu -> softmax).
 *
 * Only samples whose `model` is `"hands"` and that contain at least one hand
 * in `results.multiHandLandmarks` are used; the first hand's landmarks are
 * flattened into named coordinate features.
 */
class NeuralNet {
  /** Network used by `classify` and `save`; replaced when `train` finishes. */
  private nn: NeuralNetwork;

  /**
   * Creates the network and starts loading a previously saved model.
   *
   * @param pathToModelJson - Passed unchanged to ml5's `load`.
   * @param onLoaded - Optional callback forwarded to ml5's `load`.
   */
  constructor(
    pathToModelJson: string,
    onLoaded?: Parameters<NeuralNetwork["load"]>[1]
  ) {
    this.nn = this.createNet();
    this.nn.load(pathToModelJson, onLoaded);
  }

  /**
   * Classifies a single gesture sample with the current network.
   *
   * @param sample - Sample to classify; any label it carries is not sent to
   *   the network.
   * @param callback - Receives ml5's classification outcome, or is called
   *   with `("No data", null)` when the sample has no usable hand landmarks.
   */
  classify(
    sample: GestureData,
    callback: Parameters<NeuralNetwork["classify"]>[1]
  ): void {
    const [row] = this.formatData([sample]);
    if (row === undefined) {
      callback("No data", null);
      return;
    }
    // Drop the label so only landmark coordinates reach the network.
    const { label, ...inputs } = row;
    this.nn.classify(inputs, callback);
  }

  /**
   * Saves the current network through ml5's `save`.
   *
   * @param outputName - Name handed to ml5; defaults to `"handGestureModel"`.
   * @param onSaved - Optional callback forwarded to ml5's `save`.
   */
  save(
    outputName = "handGestureModel",
    onSaved?: Parameters<NeuralNetwork["save"]>[1]
  ): void {
    this.nn.save(outputName, onSaved);
  }

  /**
   * Trains a brand-new network on `samples`. The current network keeps
   * serving `classify`/`save` until training completes, at which point the
   * new network replaces it.
   *
   * @param samples - Labelled samples; entries without hand landmarks are skipped.
   * @param onTrained - Called after the trained network has been swapped in.
   * @param options - ml5 training options; `batchSize` defaults to 12 and
   *   `epochs` to 200 when not provided.
   */
  train(
    samples: GestureData[],
    onTrained?: Parameters<NeuralNetwork["train"]>[1],
    options: Parameters<NeuralNetwork["train"]>[0] = {}
  ): void {
    const rows = this.formatData(samples);
    const candidate = this.createNet();

    rows.forEach(({ label, ...inputs }) => {
      candidate.addData(inputs, { label });
    });
    // NOTE(review): if no sample yields hand landmarks, `rows` is empty and
    // ml5 is asked to normalise/train on no data — confirm how ml5 reacts.
    candidate.normalizeData();

    candidate.train(
      {
        ...options,
        batchSize: options.batchSize ?? 12,
        epochs: options.epochs ?? 200,
      },
      () => {
        this.nn = candidate;
        onTrained?.();
      }
    );
  }

  /**
   * Builds a new ml5 classification network with two hidden relu layers.
   * The softmax output layer has no explicit `units`; presumably ml5 sizes it
   * from the labels it sees — confirm against the ml5 docs.
   */
  private createNet(): NeuralNetwork {
    return ml5.neuralNetwork({
      debug: true,
      layers: [
        {
          type: "dense",
          units: 24,
          activation: "relu",
        },
        {
          type: "dense",
          units: 16,
          activation: "relu",
        },
        {
          type: "dense",
          activation: "softmax",
        },
      ],
      task: "classification",
    });
  }

  /**
   * Flattens usable samples into feature rows.
   *
   * Key insertion order (label first, then `0_x, 0_y, 0_z, 1_x, ...`) is kept
   * stable because the rows' keys are what get passed to ml5 as input names.
   * Samples that are not `"hands"` or have no hand landmarks are skipped.
   */
  private formatData(samples: GestureData[]): FeatureRow[] {
    const rows: FeatureRow[] = [];
    for (const sample of samples) {
      if (sample.model !== "hands") continue;
      const firstHand = sample.results.multiHandLandmarks?.[0];
      if (!firstHand) continue;

      // Suppression carried over unchanged; presumably `label` is optional or
      // absent on some GestureData variants — TODO confirm and type properly.
      const row: FeatureRow = {
        // @ts-ignore
        label: sample.label,
      };
      // Direct assignment instead of re-spreading an accumulator per landmark.
      firstHand.forEach((point, idx) => {
        row[`${idx}_x`] = point.x;
        row[`${idx}_y`] = point.y;
        row[`${idx}_z`] = point.z;
      });
      rows.push(row);
    }
    return rows;
  }
}

// Expose the gesture classifier class as this module's default export.
export { NeuralNet as default };
