/* global BigInt */

import { Float16Array } from "@petamoriken/float16";
import * as tf from "@tensorflow/tfjs";
import * as ort from "onnxruntime-web/all";
import { ASR, batch_size } from "./asr.js";
import { getData } from "./cache.js";
import {
  array2onnx_int64,
  LOG_EPS,
  onnx_tensor_zeros,
  tf2onnx,
  to_int_list,
  zero_out_ml_tensor,
} from "./tensor_utils.js";

import * as perf from "./perf.js";

/**
 * Streaming Icefall Zipformer2 ASR model built from three ONNX Runtime
 * sessions: encoder, decoder and joiner.
 *
 * Encoder states live in two output sets (`outputs0` / `outputs1`) that are
 * used alternately: each encoder step reads its cached states from the set
 * written by the previous step and writes its new states into the other one.
 * With a WebNN `ml_context` these tensors are allocated once and zeroed on
 * re-initialisation; without one they are recreated on every `init()`.
 *
 * NOTE(review): model file names, session options, model metadata,
 * `ml_context`, `params`, `get_feature_frames`, `greedy_search` and
 * `modified_beam_search` come from the `ASR` base class (not visible here).
 */
class IcefallZip2 extends ASR {
  step = 0;
  tensors_created = false;

  /**
   * Loads the base-class resources, then the encoder, decoder and joiner
   * sessions in that order. Each stage is reported to `progressHandler` and
   * logged to the console.
   *
   * @param {function({message: string}): void} progressHandler - progress callback.
   */
  async load_models(progressHandler) {
    await super.load_models(progressHandler);

    const report = (message) => {
      progressHandler({ message });
      console.log(message);
    };

    report("Loading IcefallZip2 models");
    report("Loading encoder model");
    await this.load_encoder();
    report("Loading decoder model");
    await this.load_decoder();
    report("Loading joiner model");
    await this.load_joiner();
    report("Loading IcefallZip2 models done");
  }

  /**
   * Creates the encoder session and reads the encoder metadata.
   *
   * The compute dtype is "float16" when the encoder file name contains
   * ".fp16.", otherwise "float32".
   */
  async load_encoder() {
    this.dtype = this.encoder_onnx_fn.includes(".fp16.")
      ? "float16"
      : "float32";

    const encoder_onnx = await getData(this.encoder_onnx_fn);
    const session_opts = { ...this.encoder_session_opts };
    // Overriding the free dimension N (number of signals in the batch) through
    // `freeDimensionOverrides` breaks session creation, so it is left unset.
    // Graph optimisation is disabled because even the "basic" level breaks
    // session creation for this model.
    session_opts.graphOptimizationLevel = "disabled";
    session_opts.logVerbosityLevel = 0;
    console.log("session_opts", session_opts);
    this.encoder = await ort.InferenceSession.create(
      encoder_onnx,
      session_opts
    );

    const meta = this.encoder_meta;
    this.num_encoder_layers = to_int_list(meta["num_encoder_layers"]);
    this.num_encoders = this.num_encoder_layers.length;
    this.segment = parseInt(meta["T"], 10);
    this.decode_chunk_len = parseInt(meta["decode_chunk_len"], 10);
    this.offset = this.decode_chunk_len;
    this.chunk_size = Math.trunc(this.offset / 2);
  }

  /**
   * Creates the decoder session and reads `context_size` / `vocab_size`
   * from the decoder metadata.
   */
  async load_decoder() {
    const decoder_onnx = await getData(this.decoder_onnx_fn);
    // The decoder is left on its configured (CPU) provider, so no free
    // dimension overrides are applied.
    const session_opts = { ...this.decoder_session_opts };
    console.log("session_opts", session_opts);
    this.decoder = await ort.InferenceSession.create(
      decoder_onnx,
      session_opts
    );

    this.context_size = parseInt(this.decoder_meta.context_size, 10);
    this.vocab_size = parseInt(this.decoder_meta.vocab_size, 10);
  }

  /**
   * Creates the joiner session and reads `joiner_dim` from the joiner
   * metadata.
   */
  async load_joiner() {
    const joiner_onnx = await getData(this.joiner_onnx_fn);
    // Left on its configured (CPU) provider. A free dimension override of N is
    // not used because N (the number of active paths) varies during decoding.
    const session_opts = { ...this.joiner_session_opts };
    console.log("session_opts", session_opts);
    this.joiner = await ort.InferenceSession.create(joiner_onnx, session_opts);

    this.joiner_dim = parseInt(this.joiner_meta.joiner_dim, 10);
  }

  /**
   * Releases base-class resources, then destroys every WebNN-backed tensor
   * held in `inputs`, `outputs0` and `outputs1`.
   *
   * The tensor references are cleared and `tensors_created` is reset, so a
   * later `init()` allocates fresh tensors instead of zeroing destroyed ones,
   * and a second `destroy()` does not destroy the same tensors twice.
   */
  async destroy() {
    await super.destroy();

    for (const tensors of [this.inputs, this.outputs0, this.outputs1]) {
      for (const t of Object.values(tensors ?? {})) {
        if (t?.location === "ml-tensor") {
          t.mlTensor.destroy();
        }
      }
    }

    this.inputs = null;
    this.outputs0 = null;
    this.outputs1 = null;
    this.input_states = {};
    this.tensors_created = false;
  }

  /** Resets the base class and the encoder states and restarts step counting. */
  async init() {
    perf.time("icefall_zip2.init");
    try {
      await super.init();
      await this.init_encoder_states();

      this.step = 0;
    } finally {
      perf.timeEnd("icefall_zip2.init");
    }
  }

  /**
   * Prepares the encoder input/output tensors and seeds `input_states` from
   * `outputs0`.
   *
   * With a WebNN context and already-created tensors, the existing tensors
   * are zeroed (creating them is expensive). Otherwise new ones are created.
   *
   * @throws {Error} when the encoder metadata `model_type` is not "zipformer2".
   */
  async init_encoder_states() {
    const meta = this.encoder_meta;
    const model_type = meta["model_type"];
    if (model_type !== "zipformer2") {
      throw new Error(`model_type ${model_type} is not supported`);
    }

    const decode_chunk_len = parseInt(meta["decode_chunk_len"], 10);
    // T feature frames are fed to the encoder per call; pad_length is how
    // many of them exceed one decode chunk.
    this.T = parseInt(meta["T"], 10);
    this.pad_length = this.T - decode_chunk_len;

    const encoder_dims = to_int_list(meta["encoder_dims"]);
    const cnn_module_kernels = to_int_list(meta["cnn_module_kernels"]);
    const left_context_len = to_int_list(meta["left_context_len"]);
    const query_head_dims = to_int_list(meta["query_head_dims"]);
    const value_head_dims = to_int_list(meta["value_head_dims"]);
    const num_heads = to_int_list(meta["num_heads"]);

    if (this.ml_context && this.tensors_created) {
      // Reuse the existing WebNN tensors: zeroing is cheaper than recreating.
      this._zero_out_tensors(this.inputs);
      this._zero_out_tensors(this.outputs0);
      this._zero_out_tensors(this.outputs1);
    } else {
      const args = [
        query_head_dims,
        num_heads,
        encoder_dims,
        value_head_dims,
        cnn_module_kernels,
        left_context_len,
      ];
      this.outputs0 = await this._create_output_tensors(...args);
      this.outputs1 = await this._create_output_tensors(...args);
      this.inputs = await this._create_input_tensors();

      this.tensors_created = true;
    }

    this.input_states = {};
    this.update_input_states(this.outputs0);
  }

  /**
   * Zeroes every tensor of `tensors` in place on `this.ml_context`.
   *
   * NOTE(review): `zero_out_ml_tensor` is not awaited here. Confirm whether
   * it returns a promise that should be awaited.
   *
   * @param {Object<string, *>} tensors - name -> tensor map.
   */
  _zero_out_tensors(tensors) {
    perf.time("icefall_zip2.zero_out_states");
    for (const name in tensors) {
      zero_out_ml_tensor(this.ml_context, tensors[name]);
    }
    perf.timeEnd("icefall_zip2.zero_out_states");
  }

  /**
   * Creates zero-filled encoder inputs: `x` of shape [batch_size, T, 80]
   * in `this.dtype`, and `x_lens` of shape [batch_size] as int64.
   *
   * @returns {Promise<{x: *, x_lens: *}>}
   */
  async _create_input_tensors() {
    perf.time("icefall_zip2._create_inputs");
    const inputs = {
      x: await onnx_tensor_zeros(
        [batch_size, this.T, 80],
        this.dtype,
        this.ml_context,
        false,
        true
      ),
      x_lens: await onnx_tensor_zeros(
        [batch_size],
        "int64",
        this.ml_context,
        false,
        true
      ),
    };

    perf.timeEnd("icefall_zip2._create_inputs");
    return inputs;
  }

  /**
   * Creates one zero-filled set of encoder outputs.
   *
   * For each layer j (counted across all encoder stacks) it creates
   * `new_cached_{key,nonlin_attn,val1,val2,conv1,conv2}_${j}`. It then creates
   * `new_embed_states`, `new_processed_lens`, `encoder_out` and
   * `encoder_out_lens`. Tensors are created sequentially, in this order.
   *
   * All array parameters are indexed by encoder stack.
   *
   * @param {number[]} query_head_dims
   * @param {number[]} num_heads
   * @param {number[]} encoder_dims
   * @param {number[]} value_head_dims
   * @param {number[]} cnn_module_kernels
   * @param {number[]} left_context_len
   * @returns {Promise<Object<string, *>>} name -> tensor map.
   */
  async _create_output_tensors(
    query_head_dims,
    num_heads,
    encoder_dims,
    value_head_dims,
    cnn_module_kernels,
    left_context_len
  ) {
    perf.time("icefall_zip2._create_outputs");

    // The 4th onnx_tensor_zeros flag is `true` only for `encoder_out`
    // (presumably it makes the tensor readable by the CPU-side joiner; see
    // onnx_tensor_zeros for its exact meaning).
    const zeros = (dims, dtype = this.dtype, flag = false) =>
      onnx_tensor_zeros(dims, dtype, this.ml_context, flag, true);

    const tensors = {};

    let j = 0;
    for (let i = 0; i < this.num_encoders; i++) {
      const left = left_context_len[i];
      const key_dim = query_head_dims[i] * num_heads[i];
      const embed_dim = encoder_dims[i];
      const nonlin_attn_head_dim = Math.trunc((3 * embed_dim) / 4);
      const value_dim = value_head_dims[i] * num_heads[i];
      const conv_left_pad = Math.trunc(cnn_module_kernels[i] / 2);

      for (let layer = 0; layer < this.num_encoder_layers[i]; layer++, j++) {
        tensors[`new_cached_key_${j}`] = await zeros([left, batch_size, key_dim]);
        tensors[`new_cached_nonlin_attn_${j}`] = await zeros([
          1,
          batch_size,
          left,
          nonlin_attn_head_dim,
        ]);
        tensors[`new_cached_val1_${j}`] = await zeros([left, batch_size, value_dim]);
        tensors[`new_cached_val2_${j}`] = await zeros([left, batch_size, value_dim]);
        tensors[`new_cached_conv1_${j}`] = await zeros([
          batch_size,
          embed_dim,
          conv_left_pad,
        ]);
        tensors[`new_cached_conv2_${j}`] = await zeros([
          batch_size,
          embed_dim,
          conv_left_pad,
        ]);
      }
    }

    // NOTE(review): the shapes below are hard-coded and presumably must match
    // the exported encoder (e.g. 8 output frames per chunk).
    // TODO: derive them from the model metadata.
    tensors[`new_embed_states`] = await zeros([batch_size, 128, 3, 19]);
    tensors[`new_processed_lens`] = await zeros([batch_size], "int64");
    tensors[`encoder_out`] = await zeros(
      [batch_size, 8, this.joiner_dim],
      this.dtype,
      true
    );
    tensors[`encoder_out_lens`] = await zeros([batch_size], "int64");

    perf.timeEnd("icefall_zip2._create_outputs");
    return tensors;
  }

  /**
   * Advances `step`, loads `x` / `x_lens` into the encoder inputs, and
   * selects the output set for this step.
   *
   * With WebNN, the data is written into the pre-allocated `inputs` tensors.
   * For float16, the values are converted and passed as raw 16-bit words.
   * Without WebNN, `inputs.x` / `inputs.x_lens` are replaced by new tensors.
   *
   * @param {tf.Tensor} x - feature tensor.
   * @param {number[]} x_lens - per-item feature lengths.
   * @returns {[Object, Object]} [encoder feeds, output set for this step]
   */
  _build_encoder_input_output(x, x_lens) {
    this.step++;

    if (this.ml_context) {
      let x_buffer;
      if (this.dtype === "float16") {
        const f16 = new Float16Array(x.dataSync());
        x_buffer = new Uint16Array(f16.buffer);
      } else {
        x_buffer = x.dataSync();
      }
      this.ml_context.writeTensor(this.inputs.x.mlTensorData, x_buffer);

      const lens_buffer = new BigInt64Array(x_lens.map((len) => BigInt(len)));
      this.ml_context.writeTensor(this.inputs.x_lens.mlTensorData, lens_buffer);
    } else {
      this.inputs.x = tf2onnx(this.dtype, x);
      this.inputs.x_lens = array2onnx_int64(x_lens);
    }

    const encoder_input = {
      ...this.inputs,
      ...this.input_states,
    };

    // Alternate between the two output sets so this step never writes into
    // the tensors it is reading its cached states from.
    const encoder_output = this.step % 2 === 0 ? this.outputs0 : this.outputs1;

    return [encoder_input, encoder_output];
  }

  /**
   * Copies every entry of `encoder_output` whose name contains "new" into
   * `input_states`, under the name with its first "new_" removed
   * (e.g. "new_cached_key_0" -> "cached_key_0").
   *
   * @param {Object<string, *>} encoder_output - name -> tensor map.
   */
  update_input_states(encoder_output) {
    for (const name in encoder_output) {
      if (name.includes("new")) {
        this.input_states[name.replace("new_", "")] = encoder_output[name];
      }
    }
  }

  /**
   * Runs one encoder step and updates the cached states for the next step.
   *
   * @param {tf.Tensor} x - feature tensor.
   * @param {number[]} x_lens - per-item feature lengths.
   * @returns {Promise<*>} the `encoder_out` tensor.
   */
  async run_encoder(x, x_lens) {
    perf.time("icefall_zip2.run_encoder");
    try {
      const [encoder_input, encoder_output] = this._build_encoder_input_output(
        x,
        x_lens
      );

      if (this.ml_context) {
        // WebNN: outputs are written into the pre-allocated tensors.
        await this.encoder.run(encoder_input, encoder_output);
        this.update_input_states(encoder_output);
        return encoder_output["encoder_out"];
      }

      // No WebNN: ONNX Runtime allocates the outputs. With no fetches
      // argument it returns every model output, which covers all the
      // "new_*" states.
      const out = await this.encoder.run(encoder_input);
      this.update_input_states(out);
      return out["encoder_out"];
    } finally {
      perf.timeEnd("icefall_zip2.run_encoder");
    }
  }

  /**
   * @param {*} decoder_input - ONNX tensor fed as the decoder input `y`.
   * @returns {Promise<*>} the `decoder_out` tensor.
   */
  async run_decoder(decoder_input) {
    perf.time("icefall_zip2.run_decoder");
    try {
      const out = await this.decoder.run({ y: decoder_input }, ["decoder_out"]);
      return out["decoder_out"];
    } finally {
      perf.timeEnd("icefall_zip2.run_decoder");
    }
  }

  /**
   * @param {*} encoder_out - ONNX tensor fed as joiner input `encoder_out`.
   * @param {*} decoder_out - ONNX tensor fed as joiner input `decoder_out`.
   * @returns {Promise<*>} the `logit` tensor.
   */
  async run_joiner(encoder_out, decoder_out) {
    perf.time("icefall_zip2.run_joiner");
    try {
      const out = await this.joiner.run({ encoder_out, decoder_out }, ["logit"]);
      return out["logit"];
    } finally {
      perf.timeEnd("icefall_zip2.run_joiner");
    }
  }

  /**
   * Decodes one chunk of features, if one is available.
   *
   * The features are given a batch dimension and, when shorter than
   * `decode_chunk_len + pad_length` frames, padded with LOG_EPS. They are
   * then run through the encoder and the configured search
   * ("greedy_search" or "modified_beam_search").
   *
   * @returns {Promise<number>} 1 when a chunk was decoded; 0 when the input
   *   is already decoded or no features are ready yet.
   * @throws {Error} for an unknown `params.decoding_method`.
   */
  async decode() {
    perf.time("icefall_zip2.decode");
    try {
      if (this.is_input_decoded) {
        return 0;
      }

      const [features, feature_lens] = await this.get_feature_frames();
      if (features == null) {
        // no features ready yet
        return 0;
      }

      // Tensors allocated here are released once the encoder has consumed
      // them. `features` itself is not disposed: its ownership is not visible
      // from this class (NOTE(review): confirm get_feature_frames releases it).
      const owned = [];
      let batch = features.expandDims(0);
      owned.push(batch);

      const target_len = this.decode_chunk_len + this.pad_length;
      if (batch.shape[1] < target_len) {
        batch = tf.pad(
          batch,
          [
            [0, 0],
            [0, target_len - batch.shape[1]],
            [0, 0],
          ],
          LOG_EPS
        );
        owned.push(batch);
      }

      let encoder_out;
      try {
        encoder_out = await this.run_encoder(batch, [feature_lens]);
      } finally {
        tf.dispose(owned);
      }

      if (this.params.decoding_method === "greedy_search") {
        this.hyp = await this.greedy_search(encoder_out, this.hyp);
      } else if (this.params.decoding_method === "modified_beam_search") {
        this.hyps = await this.modified_beam_search(
          encoder_out,
          this.hyps,
          this.params.num_active_paths
        );
      } else {
        throw new Error(
          `Invalid decoding_method ${this.params.decoding_method}`
        );
      }

      return 1;
    } finally {
      perf.timeEnd("icefall_zip2.decode");
    }
  }
}

export { IcefallZip2 };
