/* global BigInt */

import * as tf from "@tensorflow/tfjs";
import * as ort from "onnxruntime-web/all";
import ITN from "../itn/ITN.js";
import { CapitalizerPipeline } from "../itn/cap.js";
import { getData } from "./cache.js";
import { frame_ms, frame_samples, sample_rate } from "./consts.js";
import { onnx2tf, onnx_session_opts, tf2onnx } from "./tensor_utils.js";

import * as perf from "./perf.js";

// Duration of the context placed on each side of the centre window (ms).
const vad_context_ms = 320;

// The same duration expressed in seconds; shared by the sample counts below.
const vad_context_seconds = vad_context_ms / 1000;

// Samples in one context side: the context duration plus one frame and a
// fixed 39-sample margin.
// NOTE(review): the origin of the 39-sample margin is not documented here.
const vad_context_samples = Math.floor(
  vad_context_seconds * sample_rate + frame_samples + 39
);

// Samples in the centre window (the audio actually consumed per VAD step).
const vad_center_samples = Math.floor(vad_context_seconds * sample_rate);

// Full audio window handed to the MFCC model: left context + centre + right
// context.
const vad_audio_window_size = 2 * vad_context_samples + vad_center_samples;

// Value used for the models' `mfcc_length` free dimension.
const vad_feature_window_size = 32;

// console.log("vad_context_samples", vad_context_samples);
// console.log("vad_center_samples", vad_center_samples);
// console.log("vad_audio_window_size", vad_audio_window_size);
// console.log("vad_feature_window_size", vad_feature_window_size);

/**
 * Streaming voice activity detector.
 *
 * Audio is buffered with `accept_waveform`, and each `decode` call consumes
 * one centre window of `vad_center_samples` samples: the window (with left
 * and right context) is run through an MFCC ONNX session and then a VAD ONNX
 * session, and the resulting per-frame scores drive a small state machine
 * that reports `start_of_speech`, `speech_chunk`, `end_of_speech` and
 * `end_of_stream` events to a callback. When an ASR model is supplied, the
 * speech audio is forwarded to it and its text is attached to the events.
 *
 * Frame indices are converted to seconds by dividing by 100, and the maximum
 * segment length in ms is converted to frames by dividing by 10 — presumably
 * the VAD model emits one score per 10 ms; confirm against the model.
 */
class VAD {
  params = undefined;
  mfcc_onnx_fn = undefined;
  vad_onnx_fn = undefined;

  mfcc_session_opts = undefined;
  vad_session_opts = undefined;

  /**
   * @param {object} params - Settings read by this class:
   *   `mfcc_execution_provider`, `vad_execution_provider`,
   *   `start_of_speech_threshold`, `end_of_speech_threshold` and
   *   `speech_max_length_ms`.
   * @param {*} mfcc_onnx_fn - Passed to `getData` to obtain the MFCC model.
   * @param {*} vad_onnx_fn - Passed to `getData` to obtain the VAD model.
   */
  constructor(params, mfcc_onnx_fn, vad_onnx_fn) {
    this.params = params;
    console.log("VAD params", this.params);

    this.mfcc_onnx_fn = mfcc_onnx_fn;
    this.vad_onnx_fn = vad_onnx_fn;

    this.itn_service = new ITN();
    this.capitalizer = new CapitalizerPipeline();
  }

  /**
   * Creates the MFCC and VAD inference sessions and initialises the text
   * normalisation and capitalisation services.
   *
   * @param {function({message: string}): void} progressHandler - Called with
   *   a status message before each loading step and once when done.
   * @returns {Promise<void>}
   */
  async load_models(progressHandler) {
    this.mfcc_session_opts = onnx_session_opts(
      this.params.mfcc_execution_provider
    );
    this.vad_session_opts = onnx_session_opts(
      this.params.vad_execution_provider
    );

    progressHandler({ message: "Loading VAD models" });
    console.log("Loading models");

    progressHandler({ message: "Loading mfcc model" });
    console.log("Loading mfcc model");
    const mfcc_onnx = await getData(this.mfcc_onnx_fn);
    const mfcc_opts = {
      ...this.mfcc_session_opts,
      freeDimensionOverrides: {
        // The audio input spans left context, centre and right context.
        audio_length: vad_audio_window_size,
        mfcc_length: vad_feature_window_size,
      },
    };
    console.log("mfcc session_opts", mfcc_opts);
    this.mfcc_session = await ort.InferenceSession.create(mfcc_onnx, mfcc_opts);

    progressHandler({ message: "Loading vad model" });
    console.log("Loading vad model");
    const vad_onnx = await getData(this.vad_onnx_fn);
    const vad_opts = {
      ...this.vad_session_opts,
      freeDimensionOverrides: {
        mfcc_length: vad_feature_window_size,
      },
      // Graph optimisation is switched off for the VAD model only.
      // NOTE(review): the reason is not recorded in this file.
      graphOptimizationLevel: "disabled",
    };
    console.log("vad session_opts", vad_opts);
    this.vad_session = await ort.InferenceSession.create(vad_onnx, vad_opts);

    await this.itn_service.init();
    await this.capitalizer.init(null, "wasm", "q8");

    progressHandler({
      message: "Loading VAD models done",
    });
    console.log("Loading VAD models done");
  }

  /**
   * Resets all streaming state so a new audio stream can be processed.
   *
   * @returns {Promise<void>}
   */
  async init() {
    this.waveform = [];
    // The initial left context is low-amplitude noise instead of real audio.
    this.left_context = this._lowLevelNoise(vad_context_samples);
    this.len_accept = 0;
    this.is_input_finished = false;
    this.is_input_decoded = false;
    this.num_decoded_frames = 0;

    this.segment_id = -1;
    this.chunk_id = -1;
    this.start_of_speech = -1;
    this.end_of_speech = -1;

    this.speech_post = [];
  }

  /**
   * Appends samples to the pending audio buffer.
   *
   * Samples are appended one at a time: spreading the whole buffer into a
   * single `push(...)` call turns every sample into a call argument and
   * throws a RangeError once the buffer is large enough.
   *
   * @param {Iterable<number> & {length: number}} waveform - Samples to
   *   append (plain array or typed array).
   * @returns {Promise<void>}
   */
  async accept_waveform(waveform) {
    this.len_accept += waveform.length;
    for (const sample of waveform) {
      this.waveform.push(sample);
    }
  }

  /**
   * Marks the input as complete, after appending noise padding covering one
   * centre window plus one context length.
   *
   * @returns {Promise<void>}
   */
  async input_finished() {
    const vad_context_padding = this._lowLevelNoise(
      vad_center_samples + vad_context_samples
    );

    await this.accept_waveform(vad_context_padding);

    this.is_input_finished = true;
  }

  /**
   * Processes one centre window of buffered audio, if available, and emits
   * the resulting events.
   *
   * @param {object|null|undefined} asrModel - Optional ASR model; the methods
   *   used are `init`, `accept_waveform`, `decode`, `input_finished` and
   *   `result`. When falsy, events carry no `asr` field.
   * @param {function(object): void} resultsHandler - Receives every event.
   * @param {object} [options] - `options.convert_commands` is forwarded to
   *   text normalisation of final ASR results.
   * @returns {Promise<number>} 1 when a window was processed, 0 when the
   *   stream is already fully decoded or too little audio is buffered.
   */
  async decode(asrModel, resultsHandler, options = {}) {
    if (this.is_input_decoded) {
      return 0;
    }

    if (this.waveform.length < vad_center_samples) {
      // Not enough audio yet for a full centre window.
      return 0;
    }

    const num_available_samples = vad_center_samples;

    // Model input: remembered left context, new centre audio, and a
    // zero-filled right context.
    const center = this.waveform.slice(0, num_available_samples);
    const right_context = new Array(vad_context_samples).fill(0.0);
    const mfcc_waveform = this.left_context.concat(center, right_context);

    // Keep the most recent `vad_context_samples` samples as the next left
    // context.
    this.left_context = this.left_context.concat(center);
    this.left_context = this.left_context.slice(
      this.left_context.length - vad_context_samples,
      this.left_context.length
    );

    // Drop the consumed centre window from the pending buffer.
    this.waveform = this.waveform.slice(
      num_available_samples,
      this.waveform.length
    );

    const waveform_tensor = new ort.Tensor("float32", mfcc_waveform, [
      1,
      mfcc_waveform.length,
    ]);

    perf.time("vad.run_mfcc");
    const mfcc_results = await this.mfcc_session.run({
      audio: waveform_tensor,
    });
    perf.timeEnd("vad.run_mfcc");

    // Add a leading dimension to the MFCC output before feeding the VAD model.
    // NOTE(review): the intermediate tf tensors are not disposed here; check
    // whether onnx2tf / tf2onnx release them.
    let mfcc = await onnx2tf(mfcc_results.mfcc);
    mfcc = tf.expandDims(mfcc, 0);
    mfcc = tf2onnx("float32", mfcc);

    perf.time("vad.run_vad");
    const vad_results = await this.vad_session.run({
      mfcc: mfcc,
    });
    perf.timeEnd("vad.run_vad");

    // Collect the first `dims[0]` prediction values as per-frame scores.
    const vad_predictions = vad_results.predictions;
    const speech_post = [];
    for (let i = 0; i < vad_predictions.dims[0]; i++) {
      speech_post.push(vad_predictions.data[i]);
    }

    const start_of_speech_frames = speech_post.filter(
      (val) => val > this.params.start_of_speech_threshold
    ).length;
    const end_of_speech_frames = speech_post.filter(
      (val) => val < this.params.end_of_speech_threshold
    ).length;

    if (this.start_of_speech === -1 && this.end_of_speech === -1) {
      // Outside speech: more than one frame above the start threshold opens
      // a new segment.
      if (start_of_speech_frames > 1) {
        this.start_of_speech = this.num_decoded_frames;
        this.segment_id++;
        this.chunk_id++;

        const result = {
          event: "start_of_speech",
          segment_id: this.segment_id,
          chunk_id: this.chunk_id,
          vad: {
            start_of_speech: this.start_of_speech / 100,
            audio: center,
          },
        };

        if (asrModel) {
          await asrModel.init();
          await asrModel.accept_waveform(center);
          result.asr = await this._partialAsrResult(asrModel);
        }

        resultsHandler(result);
      }
    } else if (this.start_of_speech >= 0 && this.end_of_speech === -1) {
      // Inside speech: every window is a new chunk of the current segment.
      this.chunk_id++;

      if (end_of_speech_frames < 30) {
        // Fewer than 30 frames below the end threshold: speech continues.
        const result = {
          event: "speech_chunk",
          segment_id: this.segment_id,
          chunk_id: this.chunk_id,
          vad: {
            audio: center,
          },
        };

        if (asrModel) {
          await asrModel.accept_waveform(center);
          result.asr = await this._partialAsrResult(asrModel);
        }

        resultsHandler(result);
      } else {
        // Segment ends at the end of this window.
        this.end_of_speech = this.num_decoded_frames + speech_post.length;
        const result = {
          event: "end_of_speech",
          segment_id: this.segment_id,
          chunk_id: this.chunk_id,
          vad: {
            end_of_speech: this.end_of_speech / 100,
            audio: center,
          },
        };

        if (asrModel) {
          await asrModel.accept_waveform(center);
          result.asr = await this._finalAsrResult(asrModel, options);
        }

        resultsHandler(result);

        // Back to the "no speech" state.
        this.start_of_speech = -1;
        this.end_of_speech = -1;
      }
    }

    this.num_decoded_frames += speech_post.length;

    // Split speech segments that grow too long.
    if (!this.is_input_finished) {
      if (this.start_of_speech >= 0 && this.end_of_speech === -1) {
        // NOTE(review): num_decoded_frames already includes this window, so
        // adding speech_post.length again makes the limit trigger one window
        // early — confirm this look-ahead is intended.
        if (
          this.num_decoded_frames + speech_post.length - this.start_of_speech >
          this.params.speech_max_length_ms / 10
        ) {
          console.log("Handling too long speech segment");

          // Close the current segment with an audio-less end_of_speech ...
          await this._emitForcedEndOfSpeech(asrModel, resultsHandler, options);

          // ... and immediately open the next one at the same frame.
          this.start_of_speech = this.end_of_speech;
          this.end_of_speech = -1;

          this.segment_id++;
          this.chunk_id++;

          const result = {
            event: "start_of_speech",
            segment_id: this.segment_id,
            chunk_id: this.chunk_id,
            vad: {
              start_of_speech: this.start_of_speech / 100,
              audio: [],
            },
          };

          if (asrModel) {
            await asrModel.init();
            result.asr = await this._partialAsrResult(asrModel);
          }

          resultsHandler(result);
        }
      }
    }

    // Deliberately a separate check rather than an `else`: the flag is
    // re-read after the awaits above.
    // NOTE(review): the stream is finalised on the first window decoded after
    // input_finished(), even if more audio is still buffered — callers are
    // presumably expected to drain decode() before calling input_finished().
    if (this.is_input_finished) {
      if (this.start_of_speech >= 0 && this.end_of_speech === -1) {
        // Close a segment that is still open at the end of the input.
        await this._emitForcedEndOfSpeech(asrModel, resultsHandler, options);
      }

      this.is_input_decoded = true;

      resultsHandler({
        event: "end_of_stream",
        vad: {
          num_decoded_frames: this.num_decoded_frames,
        },
      });
    }

    return 1;
  }

  /**
   * Emits one `speech_chunk` event for the concatenation of the given audio
   * pieces and forwards that audio to the ASR model when one is supplied.
   *
   * @param {Array} in_speech - Audio pieces, flattened one level with
   *   `concat`.
   * @param {object|null|undefined} asrModel - Optional ASR model.
   * @param {function(object): void} resultsHandler - Receives the event.
   * @returns {Promise<Array>} A new empty array for the caller to keep
   *   collecting into.
   */
  async processInSpeechSegment(in_speech, asrModel, resultsHandler) {
    const speech_chunk = [].concat(...in_speech);

    this.chunk_id++;

    const result = {
      event: "speech_chunk",
      segment_id: this.segment_id,
      chunk_id: this.chunk_id,
      vad: {
        audio: speech_chunk,
      },
    };

    if (asrModel) {
      await asrModel.accept_waveform(speech_chunk);
      result.asr = await this._partialAsrResult(asrModel);
    }

    resultsHandler(result);

    return [];
  }

  /**
   * Normalises and capitalises recognised text.
   *
   * @param {string} text - Raw ASR text; falsy values are returned unchanged.
   * @param {*} convertCommands - Forwarded to the ITN service.
   * @returns {Promise<string>} The processed text.
   */
  async _normalizeText(text, convertCommands) {
    if (!text) {
      return text;
    }

    const normalizedText = this.itn_service.normalize(text, convertCommands);
    return await this.capitalizer.capitalize(normalizedText);
  }

  /** Builds an array of `length` random values in [0, 0.001). */
  _lowLevelNoise(length) {
    return Array.from({ length }, () => 0.001 * Math.random());
  }

  /** Runs the ASR decoder until it reports no more work; returns an interim result. */
  async _partialAsrResult(asrModel) {
    while (await asrModel.decode()) {
      // keep decoding until the model returns a falsy value
    }
    return {
      text: asrModel.result(),
      decoding_done: false,
    };
  }

  /** Finishes the ASR input, drains the decoder and returns the final result with normalised text. */
  async _finalAsrResult(asrModel, options) {
    await asrModel.input_finished();
    while (await asrModel.decode()) {
      // keep decoding until the model returns a falsy value
    }
    const text = asrModel.result();
    return {
      text,
      normalized_text: await this._normalizeText(
        text,
        options.convert_commands
      ),
      decoding_done: true,
    };
  }

  /** Ends the open segment at the current frame and emits an `end_of_speech` event with empty audio. */
  async _emitForcedEndOfSpeech(asrModel, resultsHandler, options) {
    this.end_of_speech = this.num_decoded_frames;

    const result = {
      event: "end_of_speech",
      segment_id: this.segment_id,
      chunk_id: this.chunk_id,
      vad: {
        end_of_speech: this.end_of_speech / 100,
        audio: [],
      },
    };

    if (asrModel) {
      result.asr = await this._finalAsrResult(asrModel, options);
    }

    resultsHandler(result);
  }
}

export { frame_ms, VAD };
