/*
 * Decompiled with CFR 0.152.
 */
package com.hankcs.hanlp.model.maxent;

import com.hankcs.hanlp.collection.dartsclone.Pair;
import com.hankcs.hanlp.collection.trie.DoubleArrayTrie;
import com.hankcs.hanlp.corpus.io.ByteArray;
import com.hankcs.hanlp.model.maxent.Context;
import com.hankcs.hanlp.model.maxent.EvalParameters;
import com.hankcs.hanlp.model.maxent.UniformPrior;
import com.hankcs.hanlp.utility.Predefine;
import com.hankcs.hanlp.utility.TextUtility;
import java.io.BufferedReader;
import java.io.Closeable;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.StringTokenizer;
import java.util.TreeMap;

/**
 * A maximum-entropy classifier.
 *
 * <p>A model maps a set of string predicates (the "context") to a probability for every
 * outcome label. Models are read from a text file and cached next to it as a binary
 * {@code <path>.bin} file, which {@link #load(String)} prefers on later runs.
 */
public class MaxEntModel {
    /** Correction constant read from the model file; passed to {@link EvalParameters}. */
    int correctionConstant;
    /** Correction parameter read from the model file; passed to {@link EvalParameters}. */
    double correctionParam;
    /** Prior used to initialise the outcome scores before parameters are added. */
    UniformPrior prior;
    /** Outcome labels, indexed by outcome id. */
    protected String[] outcomeNames;
    /** Per-predicate parameters plus the correction settings. */
    EvalParameters evalParams;
    /** Maps a predicate string to its index in {@code evalParams.getParams()}. */
    DoubleArrayTrie<Integer> pmap;

    /**
     * Computes the probability of every outcome for the given context.
     *
     * @param context predicate strings; predicates unknown to the model are ignored
     * @return a fresh array of probabilities, indexed by outcome id
     */
    public final double[] eval(String[] context) {
        return this.eval(context, new double[this.evalParams.getNumOutcomes()]);
    }

    /**
     * Computes the probability of every outcome and pairs each with its label.
     *
     * @param context predicate strings
     * @return one (label, probability) pair per outcome, in outcome-id order
     */
    public final List<Pair<String, Double>> predict(String[] context) {
        List<Pair<String, Double>> result = new ArrayList<Pair<String, Double>>(this.outcomeNames.length);
        double[] probabilities = this.eval(context);
        for (int i = 0; i < probabilities.length; ++i) {
            result.add(new Pair<String, Double>(this.outcomeNames[i], probabilities[i]));
        }
        return result;
    }

    /**
     * Returns the most probable outcome for the given context.
     *
     * <p>On ties the earliest outcome wins.
     *
     * @param context predicate strings
     * @return the best (label, probability) pair, or {@code null} if there are no outcomes
     *     or no probability is greater than -1.0 (e.g. all NaN)
     */
    public final Pair<String, Double> predictBest(String[] context) {
        double bestP = -1.0;
        Pair<String, Double> bestPair = null;
        for (Pair<String, Double> pair : this.predict(context)) {
            if (pair.getSecond() > bestP) {
                bestP = pair.getSecond();
                bestPair = pair;
            }
        }
        return bestPair;
    }

    /**
     * Convenience overload of {@link #predict(String[])} for a collection of predicates.
     *
     * @param context predicate strings
     * @return one (label, probability) pair per outcome, in outcome-id order
     */
    public final List<Pair<String, Double>> predict(Collection<String> context) {
        return this.predict(context.toArray(new String[0]));
    }

    /**
     * Computes outcome probabilities into a caller-supplied buffer.
     *
     * <p>Each predicate is looked up in {@link #pmap}. Predicates that are not found get
     * index -1 and are skipped by {@link #eval(int[], double[], EvalParameters)}. The
     * buffer is first initialised by {@code prior.logPrior} (presumably with log-priors;
     * confirm in {@link UniformPrior}).
     *
     * @param context predicate strings, must not be null
     * @param outsums buffer of length >= number of outcomes; overwritten and returned
     * @return {@code outsums}, now holding normalised probabilities
     */
    public final double[] eval(String[] context, double[] outsums) {
        assert (context != null);
        int[] scontexts = new int[context.length];
        for (int i = 0; i < context.length; ++i) {
            Integer index = this.pmap.get(context[i]);
            scontexts[i] = index == null ? -1 : index;
        }
        this.prior.logPrior(outsums);
        return MaxEntModel.eval(scontexts, outsums, this.evalParams);
    }

    /**
     * Core evaluation.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Add the parameters of every active predicate to the matching outcome scores.</li>
     *   <li>Exponentiate the scores, applying the correction term when the correction
     *       parameter is non-zero.</li>
     *   <li>Normalise so that the scores sum to 1.</li>
     * </ol>
     *
     * @param context predicate indices into {@code model.getParams()}; negative entries are ignored
     * @param prior initial per-outcome scores; modified in place
     * @param model parameters and correction settings
     * @return {@code prior}, now holding normalised probabilities
     */
    public static double[] eval(int[] context, double[] prior, EvalParameters model) {
        Context[] params = model.getParams();
        int numOutcomes = model.getNumOutcomes();
        int[] numfeats = new int[numOutcomes];
        for (int predicate : context) {
            if (predicate < 0) {
                continue; // predicate unknown to the model
            }
            Context predParams = params[predicate];
            int[] activeOutcomes = predParams.getOutcomes();
            double[] activeParameters = predParams.getParameters();
            for (int ai = 0; ai < activeOutcomes.length; ++ai) {
                int oid = activeOutcomes[ai];
                numfeats[oid]++;
                prior[oid] += activeParameters[ai];
            }
        }

        double correctionParam = model.getCorrectionParam();
        double constantInverse = model.getConstantInverse();
        double correctionConstant = model.getCorrectionConstant();
        double normal = 0.0;
        for (int oid = 0; oid < numOutcomes; ++oid) {
            prior[oid] = correctionParam != 0.0
                    ? Math.exp(prior[oid] * constantInverse
                            + (1.0 - (double) numfeats[oid] / correctionConstant) * correctionParam)
                    : Math.exp(prior[oid] * constantInverse);
            normal += prior[oid];
        }
        for (int oid = 0; oid < numOutcomes; ++oid) {
            prior[oid] /= normal;
        }
        return prior;
    }

    /**
     * Reads a model from its text form and writes the binary cache {@code path + ".bin"}.
     *
     * <p>Both streams are always closed. If anything fails after the cache file was opened,
     * the partially written cache is deleted. Otherwise {@link #load(String)} would pick up
     * the truncated file on the next call.
     *
     * @param path path of the UTF-8 text model
     * @return the model, or {@code null} if it could not be read (the error is logged)
     */
    public static MaxEntModel create(String path) {
        File binFile = new File(String.valueOf(path) + ".bin");
        BufferedReader br = null;
        DataOutputStream out = null;
        try {
            br = new BufferedReader(new InputStreamReader(new FileInputStream(path), "UTF-8"));
            out = new DataOutputStream(new FileOutputStream(binFile));
            MaxEntModel m = convert(br, out);
            out.close();
            out = null; // closed successfully: the cache is complete and must be kept
            return m;
        }
        catch (Exception e) {
            if (out != null) {
                closeQuietly(out);
                // Best-effort: drop the incomplete cache so load() falls back to the text model.
                binFile.delete();
            }
            Predefine.logger.severe("\u4ece" + path + "\u52a0\u8f7d\u6700\u5927\u71b5\u6a21\u578b\u5931\u8d25\uff01" + TextUtility.exceptionToString(e));
            return null;
        }
        finally {
            closeQuietly(br);
        }
    }

    /**
     * Parses the text model from {@code br} and mirrors every value into {@code out}.
     *
     * <p>The binary layout is, in order:
     * <ol>
     *   <li>the correction constant (int) and the correction parameter (double);</li>
     *   <li>the outcome count and the outcome labels;</li>
     *   <li>the pattern count, then for each pattern its length followed by its ints;</li>
     *   <li>the predicate count and the predicate labels;</li>
     *   <li>the trie values in sorted-label order, then the trie itself;</li>
     *   <li>all context weights as doubles.</li>
     * </ol>
     * This is the order {@link #create(ByteArray)} reads back.
     *
     * @throws IOException on read/write failure
     * @throws IllegalStateException if a predicate label occurs twice
     * @throws RuntimeException (e.g. NumberFormatException) on malformed input
     */
    private static MaxEntModel convert(BufferedReader br, DataOutputStream out) throws IOException {
        MaxEntModel m = new MaxEntModel();
        br.readLine(); // first line is skipped; its content is not used
        m.correctionConstant = Integer.parseInt(br.readLine());
        out.writeInt(m.correctionConstant);
        m.correctionParam = Double.parseDouble(br.readLine());
        out.writeDouble(m.correctionParam);

        int numOutcomes = Integer.parseInt(br.readLine());
        out.writeInt(numOutcomes);
        String[] outcomeLabels = new String[numOutcomes];
        m.outcomeNames = outcomeLabels;
        for (int i = 0; i < numOutcomes; ++i) {
            outcomeLabels[i] = br.readLine();
            TextUtility.writeString(outcomeLabels[i], out);
        }

        // Each pattern: [number of predicates sharing it, outcome id, outcome id, ...]
        int numOCTypes = Integer.parseInt(br.readLine());
        out.writeInt(numOCTypes);
        int[][] outcomePatterns = new int[numOCTypes][];
        for (int i = 0; i < numOCTypes; ++i) {
            StringTokenizer tok = new StringTokenizer(br.readLine(), " ");
            int[] infoInts = new int[tok.countTokens()];
            out.writeInt(infoInts.length);
            for (int j = 0; j < infoInts.length; ++j) {
                infoInts[j] = Integer.parseInt(tok.nextToken());
                out.writeInt(infoInts[j]);
            }
            outcomePatterns[i] = infoInts;
        }

        int numPreds = Integer.parseInt(br.readLine());
        out.writeInt(numPreds);
        TreeMap<String, Integer> predIndex = new TreeMap<String, Integer>();
        for (int i = 0; i < numPreds; ++i) {
            String label = br.readLine();
            // A duplicate would make the trie hold fewer values than numPreds and corrupt the
            // .bin cache, so this must be checked even when assertions are disabled.
            if (predIndex.containsKey(label)) {
                throw new IllegalStateException("\u91cd\u590d\u7684\u952e\uff1a " + label + " \u8bf7\u4f7f\u7528 -Dfile.encoding=UTF-8 \u8bad\u7ec3");
            }
            TextUtility.writeString(label, out);
            predIndex.put(label, i);
        }
        m.pmap = new DoubleArrayTrie<Integer>();
        m.pmap.build(predIndex);
        for (Integer index : predIndex.values()) {
            out.writeInt(index);
        }
        m.pmap.save(out);

        Context[] params = new Context[numPreds];
        int pid = 0;
        for (int[] pattern : outcomePatterns) {
            // All predicates of one pattern share the same outcome-id array.
            int[] outcomes = Arrays.copyOfRange(pattern, 1, pattern.length);
            for (int j = 0; j < pattern[0]; ++j) {
                double[] weights = new double[outcomes.length];
                for (int k = 0; k < weights.length; ++k) {
                    weights[k] = Double.parseDouble(br.readLine());
                    out.writeDouble(weights[k]);
                }
                params[pid++] = new Context(outcomes, weights);
            }
        }
        m.initEvaluation(outcomeLabels, params);
        return m;
    }

    /**
     * Reads a model from the binary cache written by {@link #create(String)}.
     *
     * <p>Errors in the byte stream are not caught here.
     *
     * @param byteArray cache contents positioned at the start of the model
     * @return the model
     */
    public static MaxEntModel create(ByteArray byteArray) {
        MaxEntModel m = new MaxEntModel();
        m.correctionConstant = byteArray.nextInt();
        m.correctionParam = byteArray.nextDouble();

        int numOutcomes = byteArray.nextInt();
        String[] outcomeLabels = new String[numOutcomes];
        m.outcomeNames = outcomeLabels;
        for (int i = 0; i < numOutcomes; ++i) {
            outcomeLabels[i] = byteArray.nextString();
        }

        int numOCTypes = byteArray.nextInt();
        int[][] outcomePatterns = new int[numOCTypes][];
        for (int i = 0; i < numOCTypes; ++i) {
            int[] infoInts = new int[byteArray.nextInt()];
            for (int j = 0; j < infoInts.length; ++j) {
                infoInts[j] = byteArray.nextInt();
            }
            outcomePatterns[i] = infoInts;
        }

        int numPreds = byteArray.nextInt();
        // Predicate labels are stored in the cache but lookup goes through the trie; skip them.
        for (int i = 0; i < numPreds; ++i) {
            byteArray.nextString();
        }
        Integer[] trieValues = new Integer[numPreds];
        for (int i = 0; i < trieValues.length; ++i) {
            trieValues[i] = byteArray.nextInt();
        }
        m.pmap = new DoubleArrayTrie<Integer>();
        m.pmap.load(byteArray, trieValues);

        Context[] params = new Context[numPreds];
        int pid = 0;
        for (int[] pattern : outcomePatterns) {
            int[] outcomes = Arrays.copyOfRange(pattern, 1, pattern.length);
            for (int j = 0; j < pattern[0]; ++j) {
                double[] weights = new double[outcomes.length];
                for (int k = 0; k < weights.length; ++k) {
                    weights[k] = byteArray.nextDouble();
                }
                params[pid++] = new Context(outcomes, weights);
            }
        }
        m.initEvaluation(outcomeLabels, params);
        return m;
    }

    /** Installs the uniform prior and the evaluation parameters shared by both loaders. */
    private void initEvaluation(String[] outcomeLabels, Context[] params) {
        this.prior = new UniformPrior();
        this.prior.setLabels(outcomeLabels);
        this.evalParams = new EvalParameters(params, this.correctionParam, this.correctionConstant, outcomeLabels.length);
    }

    /** Closes {@code c} if non-null, ignoring close failures (used only for cleanup). */
    private static void closeQuietly(Closeable c) {
        if (c == null) {
            return;
        }
        try {
            c.close();
        }
        catch (IOException ignored) {
            // cleanup only: the method's result has already been decided
        }
    }

    /**
     * Loads a model, preferring the binary cache {@code txtPath + ".bin"}.
     *
     * <p>If the cache cannot be opened, the text model is parsed instead, which also
     * regenerates the cache.
     *
     * @param txtPath path of the text model
     * @return the model, or {@code null} if the text model could not be read
     */
    public static MaxEntModel load(String txtPath) {
        ByteArray byteArray = ByteArray.createByteArray(String.valueOf(txtPath) + ".bin");
        return byteArray != null ? MaxEntModel.create(byteArray) : MaxEntModel.create(txtPath);
    }
}

