/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.util.infotheory.impl;

import com.oracle.labs.mlrg.olcut.util.MutableLong;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.tribuo.util.infotheory.impl.CachedPair;

/**
 * A count-based joint distribution over two variables.
 * <p>
 * Stores a total observation count, the count of each observed {@code (first, second)} pair,
 * and the marginal count of each value of the first and second variable.
 *
 * @param <T1> the type of the first variable.
 * @param <T2> the type of the second variable.
 */
public class PairDistribution<T1, T2> {
    /** Initial capacity of the count maps created by the static factories. */
    private static final int DEFAULT_MAP_SIZE = 20;

    /** The total number of observations. The static factories set this to the sum of the joint counts. */
    public final long count;
    /** The count of each observed pair. */
    public final Map<CachedPair<T1, T2>, MutableLong> jointCounts;
    /** The marginal count of each value of the first variable. */
    public final Map<T1, MutableLong> firstCount;
    /** The marginal count of each value of the second variable. */
    public final Map<T2, MutableLong> secondCount;

    /**
     * Builds a distribution by copying the supplied maps into new {@link LinkedHashMap}s.
     * The copies keep the source maps' iteration order.
     *
     * @param count       the total number of observations.
     * @param jointCounts the pair counts.
     * @param firstCount  the marginal counts of the first variable.
     * @param secondCount the marginal counts of the second variable.
     */
    public PairDistribution(long count, Map<CachedPair<T1, T2>, MutableLong> jointCounts, Map<T1, MutableLong> firstCount, Map<T2, MutableLong> secondCount) {
        this.count = count;
        this.jointCounts = new LinkedHashMap<>(jointCounts);
        this.firstCount = new LinkedHashMap<>(firstCount);
        this.secondCount = new LinkedHashMap<>(secondCount);
    }

    /**
     * Builds a distribution that stores the supplied maps directly, without copying them.
     * Later changes to these maps are therefore visible through this distribution.
     *
     * @param count       the total number of observations.
     * @param jointCounts the pair counts.
     * @param firstCount  the marginal counts of the first variable.
     * @param secondCount the marginal counts of the second variable.
     */
    public PairDistribution(long count, LinkedHashMap<CachedPair<T1, T2>, MutableLong> jointCounts, LinkedHashMap<T1, MutableLong> firstCount, LinkedHashMap<T2, MutableLong> secondCount) {
        this.count = count;
        this.jointCounts = jointCounts;
        this.firstCount = firstCount;
        this.secondCount = secondCount;
    }

    /**
     * Counts paired observations from two parallel lists.
     * Element {@code i} of {@code first} is paired with element {@code i} of {@code second}.
     *
     * @param first  the observations of the first variable.
     * @param second the observations of the second variable.
     * @param <T1>   the type of the first variable.
     * @param <T2>   the type of the second variable.
     * @return the joint and marginal counts of the observations.
     * @throws IllegalArgumentException if the lists have different sizes.
     */
    public static <T1, T2> PairDistribution<T1, T2> constructFromLists(List<T1> first, List<T2> second) {
        if (first.size() != second.size()) {
            throw new IllegalArgumentException("Counting requires arrays of the same length. first.size() = " + first.size() + ", second.size() = " + second.size());
        }
        LinkedHashMap<CachedPair<T1, T2>, MutableLong> abCountDist = new LinkedHashMap<>(DEFAULT_MAP_SIZE);
        LinkedHashMap<T1, MutableLong> aCountDist = new LinkedHashMap<>(DEFAULT_MAP_SIZE);
        LinkedHashMap<T2, MutableLong> bCountDist = new LinkedHashMap<>(DEFAULT_MAP_SIZE);
        long count = 0L;
        // Walk both lists in step with iterators rather than get(i). get(i) costs O(n)
        // per call on non-random-access lists such as LinkedList.
        Iterator<T1> firstIt = first.iterator();
        Iterator<T2> secondIt = second.iterator();
        while (firstIt.hasNext() && secondIt.hasNext()) {
            T1 a = firstIt.next();
            T2 b = secondIt.next();
            abCountDist.computeIfAbsent(new CachedPair<>(a, b), k -> new MutableLong()).increment();
            aCountDist.computeIfAbsent(a, k -> new MutableLong()).increment();
            bCountDist.computeIfAbsent(b, k -> new MutableLong()).increment();
            ++count;
        }
        // The LinkedHashMap overload stores these freshly built maps without copying them.
        return new PairDistribution<>(count, abCountDist, aCountDist, bCountDist);
    }

    /**
     * Derives the marginal counts from a map of pair counts.
     * The marginals are accumulated into new {@link HashMap}s of the default initial capacity.
     *
     * @param jointCount the pair counts.
     * @param <T1>       the type of the first variable.
     * @param <T2>       the type of the second variable.
     * @return a distribution containing copies of the joint and marginal counts.
     */
    public static <T1, T2> PairDistribution<T1, T2> constructFromMap(Map<CachedPair<T1, T2>, MutableLong> jointCount) {
        Map<T1, MutableLong> aCount = new HashMap<>(DEFAULT_MAP_SIZE);
        Map<T2, MutableLong> bCount = new HashMap<>(DEFAULT_MAP_SIZE);
        return constructFromMap(jointCount, aCount, bCount);
    }

    /**
     * Derives the marginal counts from a map of pair counts.
     * The marginals are accumulated into new {@link HashMap}s with the given initial capacities.
     *
     * @param jointCount the pair counts.
     * @param aSize      the initial capacity of the first marginal map.
     * @param bSize      the initial capacity of the second marginal map.
     * @param <T1>       the type of the first variable.
     * @param <T2>       the type of the second variable.
     * @return a distribution containing copies of the joint and marginal counts.
     */
    public static <T1, T2> PairDistribution<T1, T2> constructFromMap(Map<CachedPair<T1, T2>, MutableLong> jointCount, int aSize, int bSize) {
        Map<T1, MutableLong> aCount = new HashMap<>(aSize);
        Map<T2, MutableLong> bCount = new HashMap<>(bSize);
        return constructFromMap(jointCount, aCount, bCount);
    }

    /**
     * Derives the marginal counts from a map of pair counts, accumulating them into the supplied maps.
     * <p>
     * {@code aCount} and {@code bCount} are modified in place. Each pair's count is added to any
     * entry those maps already hold. All three maps are then copied into the returned distribution.
     *
     * @param jointCount the pair counts.
     * @param aCount     the map that receives the first variable's marginal counts.
     * @param bCount     the map that receives the second variable's marginal counts.
     * @param <T1>       the type of the first variable.
     * @param <T2>       the type of the second variable.
     * @return a distribution whose {@code count} is the sum of the values in {@code jointCount}.
     */
    public static <T1, T2> PairDistribution<T1, T2> constructFromMap(Map<CachedPair<T1, T2>, MutableLong> jointCount, Map<T1, MutableLong> aCount, Map<T2, MutableLong> bCount) {
        long count = 0L;
        for (Map.Entry<CachedPair<T1, T2>, MutableLong> e : jointCount.entrySet()) {
            CachedPair<T1, T2> pair = e.getKey();
            long curCount = e.getValue().longValue();
            T1 a = pair.getA();
            T2 b = pair.getB();
            aCount.computeIfAbsent(a, k -> new MutableLong()).increment(curCount);
            bCount.computeIfAbsent(b, k -> new MutableLong()).increment(curCount);
            count += curCount;
        }
        // Static types are plain Map here, so the copying constructor is selected.
        return new PairDistribution<>(count, jointCount, aCount, bCount);
    }
}

