#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @File     : eval_classification
# @Author   : LiuYan
# @Time     : 2021/4/20 21:19

from collections import Counter
from sklearn.metrics import precision_score, recall_score, f1_score

from base.evaluation.base_evaluator import BaseEvaluator


class ClassifyEvaluator(BaseEvaluator):
    """Evaluator for single-label classification results.

    Compares a list of gold labels with a list of predicted labels, prints a
    per-label report table to stdout and returns precision / recall / F1
    computed from exact-match counts.
    """

    def __init__(self):
        super(ClassifyEvaluator, self).__init__()

    @staticmethod
    def _prf(true_num: int, pred_num: int, total_num: int) -> tuple:
        """Return (precision, recall, f1) for the given counts.

        Any ratio whose denominator is zero is reported as 0 instead of
        raising ``ZeroDivisionError``.
        """
        p = true_num / pred_num if pred_num != 0 else 0
        r = true_num / total_num if total_num != 0 else 0
        f1 = 2 * p * r / (p + r) if p + r != 0 else 0
        return p, r, f1

    @staticmethod
    def _pad_label(label, width: int = 20) -> str:
        """Return ``label`` as text, right-padded with spaces to ``width``.

        Characters that take more than one byte in UTF-8 are weighted more
        heavily (a 3-byte character counts as two columns) so that rows with
        such labels stay roughly aligned. Labels wider than ``width`` are
        returned unpadded.
        """
        text = str(label)  # labels may be ints or other non-str hashables
        extra = (len(text.encode('utf-8')) - len(text)) // 2
        return text + ' ' * max(0, width - (extra + len(text)))

    def evaluate(self, true_list: list, pred_list: list) -> tuple:
        """Compute and print per-label and aggregate classification metrics.

        Args:
            true_list: Gold labels, one per sample.
            pred_list: Predicted labels, aligned index-by-index with
                ``true_list``.

        Returns:
            A 4-tuple ``(p, r, f1, dict_result)``. ``p``, ``r`` and ``f1`` are
            computed from the counts summed over all labels. ``dict_result``
            maps every label seen in either list to a dict with the keys
            ``precision``, ``recall``, ``f1-score``, ``true_num`` (correct
            predictions), ``pred_num`` (times predicted) and ``total_num``
            (times present in ``true_list``); the aggregate row is stored
            under the extra key ``'average'``.

        Raises:
            ValueError: If ``true_list`` and ``pred_list`` differ in length.
        """
        if len(true_list) != len(pred_list):
            # zip() would silently drop the tail of the longer list while the
            # Counter totals below would still include it, yielding
            # inconsistent metrics.
            raise ValueError(
                'true_list and pred_list must have the same length, got {} and {}'.format(
                    len(true_list), len(pred_list)
                )
            )

        true_labels = Counter(true_list)
        pred_labels = Counter(pred_list)
        print(true_labels)
        print(pred_labels)

        # Number of exact matches per label.
        hit_counts = Counter(true for true, pred in zip(true_list, pred_list) if true == pred)

        # Gold labels first (original order), followed by labels that were
        # only ever predicted. Without the latter, wrong predictions of an
        # unseen label would be missing from pred_num and inflate the
        # aggregate precision.
        all_labels = list(true_labels)
        all_labels.extend(label for label in pred_labels if label not in true_labels)

        dict_result = {}
        for label in all_labels:
            dict_result[label] = {
                'precision': 0,
                'recall': 0,
                'f1-score': 0,
                'true_num': hit_counts[label],
                'pred_num': pred_labels[label],
                'total_num': true_labels[label]
            }

        rule = '-' * 89
        row_format = '{0}{1:<12.4f}{2:<12.4f}{3:<12.4f}{4:<12}{5:<12}{6:<12}'

        print('\n' + rule)
        print('label_type\t\t\tp\t\t\tr\t\t\tf1\t\t\ttrue_num\t\t\tpred_num\ttotal_num')

        true_nums, pred_nums, total_nums = 0, 0, 0
        for label, stats in dict_result.items():
            true_nums += stats['true_num']
            pred_nums += stats['pred_num']
            total_nums += stats['total_num']

            p, r, f1 = self._prf(stats['true_num'], stats['pred_num'], stats['total_num'])
            print(row_format.format(self._pad_label(label), p, r, f1, stats['true_num'],
                                    stats['pred_num'], stats['total_num']), chr(12288))
            stats['precision'] = p
            stats['recall'] = r
            stats['f1-score'] = f1

        # Aggregate over the summed counts of all labels.
        p, r, f1 = self._prf(true_nums, pred_nums, total_nums)
        print(row_format.format(self._pad_label('average'), p, r, f1,
                                true_nums, pred_nums, total_nums), chr(12288))
        print(rule + '\n')

        # NOTE: a real label literally named 'average' would be overwritten here.
        dict_result['average'] = {
            'precision': p,
            'recall': r,
            'f1-score': f1,
            'true_num': true_nums,
            'pred_num': pred_nums,
            'total_num': total_nums
        }

        return p, r, f1, dict_result
