import {pipeline} from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.1.0';
import {env} from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.1.0';
import mixedCase from './artifacts/mixed_case_dict.json'

// Resolve models from `${process.env.PUBLIC_URL}/` and only from local files:
// remote model downloads are switched off.
Object.assign(env, {
    localModelPath: `${process.env.PUBLIC_URL}/`,
    allowRemoteModels: false,
    allowLocalModels: true,
});

/**
 * Restores capitalization of text with a transformers.js token-classification model
 * (`this.model`, loaded lazily on first use).
 *
 * The model assigns one label per word, applied by capitalize():
 *   'B-a', 'B-undef'  the word is left as written
 *   'B-A'             the first character is uppercased (TLD-looking words excepted)
 *   'B-AAA'           the whole word is uppercased
 *   'B-aAa'           the word is replaced by its spelling in mixed_case_dict.json, if listed
 */
class CapitalizerPipeline {
    task = 'token-classification';
    model = 'distilbert';
    // Loaded transformers.js pipeline, or null until init() has completed.
    pipeline = null;
    // Load in progress, shared by concurrent init() calls so the model is only loaded once.
    _pipeline_promise = null;

    // Words that stay lowercase when they look like the TLD of a host name ("com" in "example.com").
    tld_names = new Set(["org", "com", "net", "info", "edu", "gov", "mil", "arpa"]);

    /**
     * Load the token-classification pipeline; a no-op once it is loaded.
     *
     * Concurrent calls share a single load, so the options of the call that started the
     * load are the ones used. A failed load is not cached and can be retried.
     *
     * @param {Function|null} progress_callback forwarded to transformers.js `pipeline()`
     * @param {string} device execution device passed to `pipeline()`, default 'wasm'
     * @param {string|null} dtype one of "fp32", "fp16", "q8", "int8", "uint8", "q4",
     *                            "q4f16", "bnb4", or null to let the library choose
     * @returns {Promise<void>}
     */
    async init(progress_callback = null, device = 'wasm', dtype = null) {
        if (this.pipeline !== null) {
            return;
        }
        const loading = (this._pipeline_promise ??= pipeline(this.task, this.model, {
            progress_callback,
            device,
            dtype,
        }));
        try {
            this.pipeline = await loading;
        } finally {
            // Forget the settled promise (success is kept in this.pipeline, failure may be
            // retried), but never clear a newer load started by another caller.
            if (this._pipeline_promise === loading) {
                this._pipeline_promise = null;
            }
        }
    }

    /**
     * Apply the model's capitalization labels to `text`.
     *
     * @param {string} text input text
     * @returns {Promise<string>} `text` with capitalization applied; characters outside
     *                            labelled words are kept as they are
     */
    async capitalize(text) {
        if (this.pipeline === null) {
            await this.init();
        }

        // noinspection JSValidateTypes
        const classifier_out = await this.pipeline(text);
        const offsets = this._get_offsets_mapping(
            [classifier_out.map(({word}) => word)],
            [text]
        )[0];

        const [words, wordOffsets] = this._merge_words(classifier_out, offsets);

        let output = text;
        // Edit from the last word to the first: an edit may change the text length
        // ('ß'.toUpperCase() is 'SS', a dictionary spelling may differ in length), which
        // would shift every later offset if the words were processed left to right.
        for (let i = words.length - 1; i >= 0; i--) {
            const word = words[i];
            const [start, end] = wordOffsets[i];
            if (start === end) {
                continue;
            }

            switch (word.entity) {
                case 'B-A':
                    if (!this._is_tld(words, wordOffsets, i)) {
                        output = this._uppercase_string(output, start, start + 1);
                    }
                    break;
                case 'B-AAA':
                    output = this._uppercase_string(output, start, end);
                    break;
                case 'B-aAa':
                    // hasOwn rather than `in`: `in` also matches inherited keys such as
                    // "constructor", whose value is a function, not a spelling.
                    if (Object.hasOwn(mixedCase, word.word)) {
                        output = this._replace_substring(output, mixedCase[word.word], start, end);
                    }
                    break;
                default:
                    // 'B-a', 'B-undef' and any other label: keep the word as written.
                    break;
            }
        }
        return output;
    }

    /**
     * True when words[i] looks like the TLD of a host name: it is listed in `tld_names`,
     * the previous word is '.', and the word before that ends right at that dot
     * (e.g. "com" in "example.com").
     */
    _is_tld(words, wordOffsets, i) {
        return this.tld_names.has(words[i].word)
            && i >= 2
            && words[i - 1].word === '.'
            && wordOffsets[i - 2][1] === wordOffsets[i][0] - 1;
    }

    /** Return `text` with the characters in [idx_start, idx_end) uppercased. */
    _uppercase_string(text, idx_start, idx_end) {
        return text.slice(0, idx_start) + text.slice(idx_start, idx_end).toUpperCase() + text.slice(idx_end);
    }

    /** Return `text` with the characters in [idx_start, idx_end) replaced by `replace_text`. */
    _replace_substring(text, replace_text, idx_start, idx_end) {
        return text.slice(0, idx_start) + replace_text + text.slice(idx_end);
    }

    /**
     * Lowercase `str` without changing its length, so indices into the result are valid
     * indices into `str`. Plain toLowerCase() is not length-preserving: 'İ' (U+0130)
     * becomes 'i' + U+0307, which would shift every later offset by one.
     *
     * @param {string} str
     * @returns {string} same-length lowercase form of `str`
     */
    _fold_case(str) {
        const lower = str.toLowerCase();
        if (lower.length === str.length) {
            return lower;
        }
        // Rare slow path: fold one code point at a time and keep only the first code point
        // of any expansion (or the original character if even that changes the length).
        let folded = '';
        for (const ch of str) {
            const chLower = ch.toLowerCase();
            if (chLower.length === ch.length) {
                folded += chLower;
            } else {
                const head = String.fromCodePoint(chLower.codePointAt(0));
                folded += head.length === ch.length ? head : ch;
            }
        }
        return folded;
    }

    /**
     * Merge WordPiece continuation tokens into the word they belong to.
     *
     * A token with an empty span (start === end) was not found verbatim in the text. If it
     * is a continuation piece ("##ing"), its text is appended to the preceding kept word and
     * that word's end offset grows by the piece length. Any other unmatched token cannot be
     * placed in the text, so it is dropped together with its continuation pieces instead of
     * being glued onto the previous word.
     *
     * The inputs are not mutated; returned tokens and offsets are copies.
     *
     * @param {Object[]} tokens classifier output, one object per token (uses `word`)
     * @param {Array[]} offsets `[start, end, token]` per token, parallel to `tokens`
     * @returns {[Object[], Array[]]} merged words and their `[start, end, text]` spans
     */
    _merge_words(tokens, offsets) {
        const tokensOut = [];
        const offsetsOut = [];
        // Whether the latest word-initial token was kept, i.e. pieces have a word to join.
        let canMerge = false;

        tokens.forEach((token, i) => {
            const offset = offsets[i];
            if (offset[0] !== offset[1]) {
                tokensOut.push({...token});
                offsetsOut.push([...offset]);
                canMerge = true;
                return;
            }
            if (!token.word.startsWith('##')) {
                // Unmatched whole word: skip it and any pieces that follow it.
                canMerge = false;
                return;
            }
            if (!canMerge) {
                return;
            }
            const piece = token.word.slice(2);
            const lastToken = tokensOut[tokensOut.length - 1];
            const lastOffset = offsetsOut[offsetsOut.length - 1];
            lastToken.word += piece;
            lastOffset[1] += piece.length;
            lastOffset[2] += piece;
        });
        return [tokensOut, offsetsOut];
    }

    /**
     * Estimate offsets mapping from original context string.
     *
     * Tokens are located left to right with indexOf, each search starting where the previous
     * match ended. A token that is not found gets an empty span `[cursor, cursor, token]`.
     * For a WordPiece continuation ("##ing") whose text sits right at the cursor, the cursor
     * is moved past that text so the next token is not matched inside the current word.
     * Neither `search` nor `context` is modified.
     *
     * @param {BatchEncoding|string|string[]|string[][]} search Object with input_ids from tokenizer, array[][] tokens, or space delimited string tokens
     * @param {string|string[]} context
     * @param {string} strategy 'none' or 'closest'
     * @param {boolean} caseSensitive when false, tokens and context are compared after
     *                                length-preserving lowercasing (see _fold_case)
     * @returns {any[]} one array per search row of (char_start, char_end, token)
     */
    _get_offsets_mapping(search, context, strategy = 'none', caseSensitive = false) {
        const fold = (str) => (caseSensitive ? str : this._fold_case(str));

        let rows;
        if (search !== null && typeof search === 'object' && 'input_ids' in search) {
            rows = search.input_ids.tolist();
        } else {
            rows = Array.isArray(search) ? [...search] : [search];
            if (typeof rows[0] === 'string') {
                rows = rows.map((val) => val.split(' '));
            }
        }
        const contexts = (typeof context === 'string' ? [context] : context).map(fold);

        return rows.map((rowTokens, i) => {
            const ctx = contexts[i];
            let tokens = rowTokens;
            // Non-string rows hold token ids; an empty row needs no tokenizer (which may not
            // be loaded), so it is left alone.
            const isEmptyRow = Array.isArray(tokens) && tokens.length === 0;
            if (!isEmptyRow && typeof tokens[0] !== 'string') {
                if ('input_ids' in tokens) tokens = tokens.input_ids;
                tokens = this.pipeline.tokenizer.model.convert_ids_to_tokens(tokens);
            }

            const row = [];
            let lastIdx = 0;
            for (const token of tokens) {
                const needle = fold(token);
                const idx = ctx.indexOf(needle, lastIdx);

                // look behind and find closest match
                if (strategy === 'closest' && idx >= 0) {
                    this._closest_realign(row, ctx, idx);
                }

                if (idx < 0) {
                    row.push([lastIdx, lastIdx, token]);
                    const piece = needle.startsWith('##') ? needle.slice(2) : '';
                    if (piece !== '' && ctx.startsWith(piece, lastIdx)) {
                        lastIdx += piece.length;
                    }
                } else {
                    row.push([idx, idx + token.length, token]);
                    lastIdx = idx + token.length;
                }
            }
            return row;
        });
    }

    /**
     * 'closest' strategy: walk back over the spans already placed in `row` and re-anchor
     * each to the last matching token in a re-tokenization of the text before `idx`,
     * stopping at the first span that is not found. Mutates `row` in place.
     *
     * NOTE(review): relies on `tokenizer._call(..., {return_offsets_mapping: true})`
     * returning `offset_mapping`; confirm the transformers.js tokenizer supports this
     * before using the 'closest' strategy.
     */
    _closest_realign(row, ctx, idx) {
        let cursor = idx;
        for (let a = row.length - 1; a >= 0; a--) {
            const prevToken = row[a][2];
            const strStart = a > 0 ? cursor - 1 - (row[a][0] - row[a - 1][0]) : 0;
            const strEncodings = this.pipeline.tokenizer._call(ctx.substring(strStart, idx), {
                return_offsets_mapping: true
            });
            const strTokens = strEncodings.offset_mapping[0].map((offset) => offset[2]);
            const strIdx = strTokens.lastIndexOf(prevToken);
            if (strIdx < 0) {
                break;
            }
            cursor = strStart + strEncodings.offset_mapping[0][strIdx][0];
            row[a] = [cursor, cursor + prevToken.length, prevToken];
        }
    }
}

export {CapitalizerPipeline}