/* global BigInt */

import * as tf from "@tensorflow/tfjs";
import * as ort from "onnxruntime-web/all";
import { getData, getDataAsJson, getDataAsText } from "./cache.js";
import { frame_samples } from "./consts.js";
import {
  LOG_EPS,
  logaddexp,
  ml_context,
  onnx2tf,
  onnx_session_opts,
  tf2onnx,
  to_int_list,
} from "./tensor_utils.js";

import * as perf from "./perf.js";

// Batch size constant; it is exported at the bottom of this module and is not
// otherwise referenced by the code shown in this file.
const batch_size = 1;

// const asr_feature_window_size = 32;
// const asr_audio_window_size = (asr_feature_window_size + 2) * frame_samples;
// The + 2 adds some padding; without it the fbank would not produce the full 32 frames.
// This may still not be perfect, and a better implementation may be needed.

/**
 * Two-way lookup table between tokens and their integer ids.
 *
 * The table is filled asynchronously from a text file in which every
 * non-empty line has the form `<token> <idx>`. Lookups made before loading
 * has finished return `undefined`; await `ready` to be sure the table is
 * populated.
 */
class SymbolTable {
  /**
   * Start loading the token table.
   *
   * @param {string} tokens_fn - Location of the token file, handed to
   *   `getDataAsText`.
   */
  constructor(tokens_fn) {
    this.token2idx = {};
    this.idx2token = {};

    // Settles once loading has finished. A load failure is logged and the
    // promise still resolves, leaving the table empty.
    this.ready = getDataAsText(tokens_fn)
      .then((text) => {
        for (const line of text.split(/\r\n|\n/)) {
          // Skip blank lines (typically the one after a trailing newline)
          // instead of assuming the last line is always empty, so a file
          // without a trailing newline keeps its final token.
          if (line.length === 0) {
            continue;
          }

          const [token, idx_text] = line.split(" ");
          const idx = Number.parseInt(idx_text, 10);
          if (Number.isNaN(idx)) {
            // Malformed line: there is no usable id to register.
            continue;
          }

          this.token2idx[token] = idx;
          this.idx2token[idx] = token;
        }
      })
      .catch((e) => console.error(e));
  }

  /**
   * @param {number} idx - Token id.
   * @returns {string|undefined} The token registered for `idx`, if any.
   */
  get_token(idx) {
    return this.idx2token[idx];
  }

  /**
   * @param {string} token - Token text.
   * @returns {number|undefined} The id registered for `token`, if any.
   */
  get_idx(token) {
    return this.token2idx[token];
  }

  /**
   * Turn a sequence of token ids into lower-cased text.
   *
   * @param {number[]} hyp - Token ids; the first `context_size` entries are
   *   skipped.
   * @param {number} context_size - Number of leading ids to ignore.
   * @returns {string} Concatenated tokens with U+2581 replaced by a space,
   *   trimmed and lower-cased.
   */
  decode_bpe(hyp, context_size) {
    const pieces = hyp
      .slice(context_size)
      .map((token) => this.get_token(token));

    return pieces
      .join("")
      .replaceAll(String.fromCharCode(9601), " ")
      .trim()
      .toLowerCase();
  }

  /**
   * Decode the `topk` best hypotheses (ranked without length normalisation).
   *
   * @param {object} hyps - Object exposing `topk(k, length_norm)` that returns
   *   hypotheses with `ys` and `log_prob`.
   * @param {number} context_size - Number of leading ids to ignore per hypothesis.
   * @param {number} topk - How many hypotheses to decode.
   * @returns {Array<[number, string]>} `[probability, text]` pairs, best first.
   */
  decode_nbest_bpe(hyps, context_size, topk) {
    return hyps
      .topk(topk, false)
      .map((v) => [Math.exp(v.log_prob), this.decode_bpe(v.ys, context_size)]);
  }
}

/**
 * A single decoding hypothesis: a sequence of token ids together with its
 * log-probability.
 */
class Hypothesis {
  /**
   * @param {Array} ys - Token ids of this hypothesis.
   * @param {number} log_prob - Log-probability attached to `ys`.
   */
  constructor(ys, log_prob) {
    Object.assign(this, { ys, log_prob });
  }

  /**
   * @returns {string} The token ids joined with "_"; hypotheses with the same
   *   token sequence produce the same key.
   */
  key() {
    const { ys } = this;
    return ys.join("_");
  }
}

/**
 * Collection of hypotheses indexed by `key()`. Adding a hypothesis whose key
 * is already present merges its probability into the stored one.
 */
class HypothesisList {
  constructor() {
    // key() string -> hypothesis
    this.data = new Map();
  }

  /** @returns {number} Number of distinct hypotheses stored. */
  length() {
    return this.data.size;
  }

  /** @returns {Array} The stored hypotheses, in insertion order. */
  values() {
    return [...this.data.values()];
  }

  /**
   * Insert `new_hyp`, or fold its log-probability into the stored hypothesis
   * that shares its key (via `logaddexp`).
   *
   * @param {object} new_hyp - Hypothesis exposing `key()` and `log_prob`.
   */
  add(new_hyp) {
    const key = new_hyp.key();

    if (!this.data.has(key)) {
      this.data.set(key, new_hyp);
      return;
    }

    const stored = this.data.get(key);
    stored.log_prob = logaddexp(stored.log_prob, new_hyp.log_prob);
  }

  /**
   * @param {boolean} length_norm - Rank by log-probability per token when truthy.
   * @returns {object|undefined} The best hypothesis, or `undefined` when empty.
   */
  get_most_probable(length_norm) {
    const [best] = this.topk(1, length_norm);
    return best;
  }

  /**
   * Return up to `k` hypotheses ordered from most to least probable.
   *
   * @param {number} k - Maximum number of hypotheses to return.
   * @param {boolean} length_norm - When truthy, the ranking score is
   *   `log_prob / ys.length`; otherwise it is `log_prob`.
   * @returns {Array} A new array; the stored collection is left untouched.
   */
  topk(k, length_norm) {
    const score = length_norm
      ? (hyp) => hyp.log_prob / hyp.ys.length
      : (hyp) => hyp.log_prob;

    const ranked = [...this.data.values()].sort((a, b) => {
      const score_a = score(a);
      const score_b = score(b);
      if (score_a < score_b) {
        return 1;
      }
      if (score_a > score_b) {
        return -1;
      }
      return 0;
    });

    return ranked.slice(0, k);
  }
}

/**
 * Shared state and decoding logic for a chunk-based speech recogniser built
 * from fbank, encoder, decoder and joiner ONNX models.
 *
 * This class only creates the fbank session. It reads `decode_chunk_len`,
 * `pad_length` and `context_size` and calls `run_decoder` / `run_joiner`, none
 * of which are defined here; they are expected to be supplied by a derived
 * class.
 */
class ASR {
  params = undefined;
  model_path = undefined;

  fbank_session_opts = undefined;
  encoder_session_opts = undefined;
  decoder_session_opts = undefined;
  joiner_session_opts = undefined;
  ml_context = undefined;

  /**
   * @param {object} params - Recogniser settings (execution providers,
   *   `blank_id`, `decoding_method`, `topk`, `encoder_model_type`, ...).
   * @param {string} model_path - Base location of the model files.
   */
  constructor(params, model_path) {
    this.params = params;
    console.log("ASR params", this.params);

    this.model_path = model_path;
  }

  /**
   * Placeholder for any cleanup that may be needed; it can be overridden in
   * derived classes.
   */
  async destroy() {}

  /**
   * Prepare session options, read the model metadata, start loading the token
   * table and create the fbank inference session.
   *
   * @param {function} [progressHandler] - Called with `{ message }` before
   *   each loading step. Optional; defaults to a no-op.
   */
  async load_models(progressHandler = () => {}) {
    // FIXME: the ML context only depends on the encoder's execution provider.
    this.ml_context = await ml_context(this.params.encoder_execution_provider);

    this.fbank_session_opts = onnx_session_opts(
      this.params.fbank_execution_provider,
      this.ml_context
    );
    this.encoder_session_opts = onnx_session_opts(
      this.params.encoder_execution_provider,
      this.ml_context
    );
    this.decoder_session_opts = onnx_session_opts(
      this.params.decoder_execution_provider,
      this.ml_context
    );
    this.joiner_session_opts = onnx_session_opts(
      this.params.joiner_execution_provider,
      this.ml_context
    );

    progressHandler({ message: "Loading model meta" });
    console.log("Loading model meta");
    console.log(this.model_path);
    this.modelmeta = await getDataAsJson(this.model_path + "/modelmeta.json");

    // The "default" encoder lives under `encoder`; any other type under
    // `encoder_<type>`.
    const encoder_key =
      this.params.encoder_model_type === "default"
        ? "encoder"
        : `encoder_${this.params.encoder_model_type}`;
    this.encoder_meta = this.modelmeta[encoder_key];
    this.decoder_meta = this.modelmeta.decoder;
    this.joiner_meta = this.modelmeta.joiner;

    this.tokens_fn = this.model_path + "/tokens.txt";
    this.fbank_onnx_fn = this.model_path + "/tack_fbank.onnx";
    this.encoder_onnx_fn =
      this.model_path + "/" + this.encoder_meta.model_basename;
    this.decoder_onnx_fn =
      this.model_path + "/" + this.decoder_meta.model_basename;
    this.joiner_onnx_fn =
      this.model_path + "/" + this.joiner_meta.model_basename;

    progressHandler({ message: "Loading token table" });
    console.log("Loading token table");
    this.symbol_table = new SymbolTable(this.tokens_fn);

    progressHandler({ message: "Loading fbank model" });
    console.log("Loading fbank model");
    const fbank_onnx = await getData(this.fbank_onnx_fn);
    // Copy the shared options so the override below does not leak into them.
    const session_opts = {
      ...this.fbank_session_opts,
      freeDimensionOverrides: {},
    };
    console.log("fbank session_opts", session_opts);
    this.fbank_session = await ort.InferenceSession.create(
      fbank_onnx,
      session_opts
    );
  }

  /**
   * Reset all per-utterance state: buffered audio, feature frames, progress
   * counters and the decoding results (greedy `hyp` and beam `hyps`).
   */
  async init() {
    this.waveform = [];
    this.frames = [];
    this.len_accept = 0;
    this.is_input_finished = false;
    this.is_input_decoded = false;
    this.num_decoded_frames = 0;

    // The decoding result (partial or final) of the current utterance.
    this.hyp = [];
    this.hyps = new HypothesisList();
    this.hyps.add(
      new Hypothesis([this.params.blank_id, this.params.blank_id], 0.0)
    );
  }

  /**
   * Buffer new audio samples and convert all not-yet-processed audio into
   * feature frames with the fbank model.
   *
   * @param {ArrayLike<number>} waveform - Audio samples to append.
   */
  async accept_waveform(waveform) {
    this.len_accept += waveform.length;
    // Append sample by sample: spreading a large buffer into push() passes
    // every sample as a call argument and can exceed the engine's limit.
    for (const sample of waveform) {
      this.waveform.push(sample);
    }

    // Only run the fbank model while at least four frames' worth of samples
    // are still unprocessed.
    const min_samples = 4 * frame_samples;

    while (true) {
      const slice_start = this.frames.length * frame_samples;
      if (this.waveform.length - slice_start < min_samples) {
        break;
      }

      const unprocessed_waveform = this.waveform.slice(slice_start);
      const waveform_tensor = new ort.Tensor("float32", unprocessed_waveform, [
        1,
        unprocessed_waveform.length,
      ]);

      perf.time("asr.run_fbank");
      const fbank_results = await this.fbank_session.run({
        audio: waveform_tensor,
      });
      perf.timeEnd("asr.run_fbank");

      const fbank = await onnx2tf(fbank_results.fbank);
      const new_frames = tf.unstack(fbank);
      if (new_frames.length === 0) {
        // No progress was made; looping again would process the very same
        // samples forever.
        break;
      }
      for (const frame of new_frames) {
        this.frames.push(frame);
      }
    }
  }

  /**
   * Signal the end of the audio stream. Low-amplitude random samples worth two
   * decode chunks are appended first, then the input is marked as finished.
   */
  async input_finished() {
    const tail_padding = new Array(frame_samples * this.decode_chunk_len * 2)
      .fill()
      .map(() => 0.001 * Math.random());
    await this.accept_waveform(tail_padding);
    this.is_input_finished = true;
  }

  /**
   * Fetch the next block of feature frames to decode.
   *
   * A full block holds `decode_chunk_len + pad_length` frames while the read
   * position only advances by `decode_chunk_len`, so consecutive blocks
   * overlap by `pad_length` frames. After the input has finished, a final
   * shorter block with the remaining frames is returned.
   *
   * @returns {Promise<Array>} `[features, length]`; `[null, 0]` when there is
   *   nothing to return yet (more audio is needed) or nothing left at all.
   */
  async get_feature_frames() {
    const chunk_length = this.decode_chunk_len + this.pad_length;
    const available = this.frames.length - this.num_decoded_frames;
    const is_short = available < chunk_length;

    if (is_short && !this.is_input_finished) {
      // Not enough frames for a full block yet; wait for more audio.
      return [null, 0];
    }

    if (this.is_input_finished && available <= 0) {
      // Everything has been handed out already. tf.stack() rejects an empty
      // list, so report "no data" instead of throwing.
      this.is_input_decoded = true;
      return [null, 0];
    }

    // Either one full block, or (input finished) whatever is left.
    const ret_length = is_short ? available : chunk_length;
    const advance = is_short ? available : this.decode_chunk_len;

    const ret_features = tf.stack(
      this.frames.slice(
        this.num_decoded_frames,
        this.num_decoded_frames + ret_length
      )
    );
    this.num_decoded_frames += advance;

    if (
      this.is_input_finished === true &&
      this.num_decoded_frames >= this.frames.length
    ) {
      this.is_input_decoded = true;
    }

    return [ret_features, ret_length];
  }

  /**
   * Greedy search with at most one symbol emitted per frame.
   *
   * @param {object} encoder_out - Encoder output; per the original notes a
   *   3-D tensor of shape (1, T, joiner_dim).
   * @param {number[]} hyp - Decoding result of the previous chunks. It is
   *   extended in place; `null` or an empty array starts a new result.
   * @returns {Promise<number[]>} The decoded token ids so far.
   */
  async greedy_search(encoder_out, hyp) {
    perf.time("asr.greedy_search");

    // Use the configured blank id (as init() does); fall back to 0.
    const blank_id = Number(this.params.blank_id ?? 0);
    perf.time("asr.greedy_search.encoder_out");
    encoder_out = tf.squeeze(await onnx2tf(encoder_out, this.ml_context), 0);
    perf.timeEnd("asr.greedy_search.encoder_out");
    const T = encoder_out.shape[0];

    if (hyp == null || hyp.length === 0) {
      hyp = [blank_id, blank_id];
    }

    // The decoder is fed the two most recent tokens of the hypothesis.
    const last_two_tokens = () =>
      new ort.Tensor(
        "int64",
        [BigInt(hyp[hyp.length - 2]), BigInt(hyp[hyp.length - 1])],
        [1, 2]
      );

    let decoder_out = await this.run_decoder(last_two_tokens());

    for (let t = 0; t < T; t++) {
      const cur_encoder_out = tf2onnx("float32", encoder_out.slice(t, 1));
      const joiner_result = await this.run_joiner(cur_encoder_out, decoder_out);
      const joiner_out = tf.squeeze(await onnx2tf(joiner_result), 0);
      const y = joiner_out.argMax(0).dataSync()[0];

      if (y !== blank_id) {
        // A new token changes the decoder context, so rerun the decoder.
        hyp.push(y);
        decoder_out = await this.run_decoder(last_two_tokens());
      }
    }

    perf.timeEnd("asr.greedy_search");
    return hyp;
  }

  /**
   * Beam search with at most one symbol emitted per frame.
   *
   * @param {object} encoder_out - Encoder output; per the original notes a
   *   3-D tensor of shape (1, T, joiner_dim).
   * @param {HypothesisList} hyps - Decoding results of the previous chunks;
   *   `null` or an empty list starts from a single blank-only hypothesis.
   * @param {number} num_active_paths - Number of paths kept per frame.
   * @returns {Promise<HypothesisList>} The hypotheses after the last frame.
   */
  async modified_beam_search(encoder_out, hyps, num_active_paths) {
    perf.time("asr.modified_beam_search");

    // Use the configured blank id (as init() does); fall back to 0.
    const blank_id = Number(this.params.blank_id ?? 0);
    perf.time("asr.modified_beam_search.encoder_out.modified_beam_search");
    encoder_out = tf.squeeze(await onnx2tf(encoder_out, this.ml_context), 0);
    perf.timeEnd("asr.modified_beam_search.encoder_out.modified_beam_search");

    const T = encoder_out.shape[0];

    // HypothesisList exposes length() as a method, so comparing the property
    // itself with 0 never detects an empty list.
    let hyp_count;
    if (hyps == null) {
      hyp_count = 0;
    } else if (typeof hyps.length === "function") {
      hyp_count = hyps.length();
    } else {
      hyp_count = hyps.length;
    }
    if (hyp_count === 0) {
      hyps = new HypothesisList();
      hyps.add(new Hypothesis([blank_id, blank_id], 0.0));
    }

    let new_hyps = hyps;

    for (let t = 0; t < T; t++) {
      const frame = tf.squeeze(encoder_out.slice(t, 1), 0);

      const len_hyps = new_hyps.length();
      const old_hyps = new_hyps.values();
      new_hyps = new HypothesisList();

      // One row per hypothesis so it broadcasts over that hypothesis's tokens.
      const ys_log_probs = tf.tensor2d(old_hyps.map((a) => [a.log_prob]));

      // Decoder input: the two most recent tokens of every hypothesis.
      const all_hyp_decoder_input = old_hyps.flatMap((hyp) => [
        BigInt(hyp.ys[hyp.ys.length - 2]),
        BigInt(hyp.ys[hyp.ys.length - 1]),
      ]);
      const decoder_input = new ort.Tensor("int64", all_hyp_decoder_input, [
        len_hyps,
        2,
      ]);
      const decoder_out = await this.run_decoder(decoder_input);

      // Every hypothesis is scored against the same encoder frame.
      const all_hyp_cur_encoder_out = Array.from(
        { length: len_hyps },
        () => frame
      );
      const cur_encoder_out = tf2onnx(
        "float32",
        tf.stack(all_hyp_cur_encoder_out)
      );

      const joiner_result = await this.run_joiner(cur_encoder_out, decoder_out);
      const logits = await onnx2tf(joiner_result);

      // Score of extending hypothesis i with token j.
      let log_probs = logits.logSoftmax(-1).add(ys_log_probs);
      const vocab_size = log_probs.shape[1];
      log_probs = log_probs.reshape([-1]);

      const { values, indices } = tf.topk(log_probs, num_active_paths);
      const topk_log_probs = values.arraySync();
      // A flat index encodes (hypothesis, token) as hyp * vocab_size + token.
      const topk_hyp_indexes = indices.floorDiv(vocab_size).arraySync();
      const topk_token_indexes = indices.mod(vocab_size).arraySync();

      for (let k = 0; k < topk_hyp_indexes.length; k++) {
        const hyp = old_hyps[topk_hyp_indexes[k]];

        const new_ys = [...hyp.ys];
        const new_token = topk_token_indexes[k];
        if (new_token !== blank_id) {
          new_ys.push(new_token);
        }

        new_hyps.add(new Hypothesis(new_ys, topk_log_probs[k]));
      }
    }

    perf.timeEnd("asr.modified_beam_search");

    return new_hyps;
  }

  /**
   * Obtain the current decoding result.
   *
   * @param {boolean} [nbest=false] - With "modified_beam_search", return the
   *   `params.topk` best `[probability, text]` pairs instead of a single text.
   * @returns {string|Array} Decoded text, or the n-best list.
   * @throws {Error} When `params.decoding_method` is not a known method.
   */
  result(nbest = false) {
    perf.time("asr.result");

    let r = null;
    if (this.params.decoding_method === "greedy_search") {
      r = this.symbol_table.decode_bpe(this.hyp, this.context_size);
    } else if (this.params.decoding_method === "modified_beam_search") {
      if (nbest) {
        r = this.symbol_table.decode_nbest_bpe(
          this.hyps,
          this.context_size,
          this.params.topk
        );
      } else {
        // Only rank the hypotheses when the single best one is needed.
        const best_hyp = this.hyps.get_most_probable(false);
        r = this.symbol_table.decode_bpe(best_hyp.ys, this.context_size);
      }
    } else {
      throw new Error(`Invalid decoding_method ${this.params.decoding_method}`);
    }

    perf.timeEnd("asr.result");
    return r;
  }
}

// Public surface of this module. ASR, Hypothesis, HypothesisList, SymbolTable
// and batch_size are defined above; LOG_EPS, logaddexp and to_int_list are
// re-exported unchanged from ./tensor_utils.js.
export {
  ASR,
  batch_size,
  Hypothesis,
  HypothesisList,
  LOG_EPS,
  logaddexp,
  SymbolTable,
  to_int_list
};

