#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time    : 2021/5/11 16:30
# @Author  : 程婷婷
# @FileName: BaseEvaluator.py
# @Software: PyCharm
from sklearn.metrics import precision_score, f1_score, recall_score, classification_report
import logging
from base.views.config.BaseConfig import BaseConfig
# Log record layout: timestamp, level, source file path, calling function,
# then the message itself, separated by single spaces.
formats = ' '.join((
    '%(asctime)s',
    '%(levelname)s',
    '%(pathname)s',
    '%(funcName)s',
    '%(message)s',
))
# Configure the root logger at import time (INFO and above) with that layout.
logging.basicConfig(level=logging.INFO, format=formats)


class BaseEvaluator:
    """Compute and log classification metrics for model predictions.

    The averaging mode passed to sklearn's scoring functions is read from the
    ``average`` key of the ``evaluate`` section of the config file given to
    the constructor.
    """

    def __init__(self, config_path):
        """Load the ``evaluate`` section of the config at *config_path*.

        :param config_path: path handed to ``BaseConfig``; its parsed file
            must contain an ``evaluate`` mapping (read later for ``average``).
        """
        self.evaluate_config = BaseConfig(config_path)._parsed_file['evaluate']

    def evaluate(self, y_true, y_pred, label_mapping, logger=None):
        """Evaluate predictions, log the metrics and return them as text.

        Labels in *y_true* and *y_pred* are compared as strings, so ``1`` and
        ``'1'`` are treated as the same class.

        :param y_true: iterable of ground-truth labels.
        :param y_pred: iterable of predicted labels, aligned with *y_true*.
        :param label_mapping: mapping of display name -> encoded label value.
            When empty or None, sklearn's full ``classification_report`` is
            produced instead of per-label scores.
        :param logger: logger to write results to; defaults to this module's
            logger when omitted.
        :returns: the classification report text, or one ``str(dict)`` with
            ``label``/``recall``/``precision``/``F1`` per entry of
            *label_mapping*, joined by single spaces.
        """
        if logger is None:
            logger = logging.getLogger(__name__)
        y_true = [str(label) for label in y_true]
        y_pred = [str(label) for label in y_pred]
        logger.info('模型评估结果如下：')

        if not label_mapping:
            # Build the report once and reuse it for both log and return value.
            report = classification_report(y_true, y_pred)
            logger.info(report)
            return report

        average = self.evaluate_config['average']
        result = []
        for name, value in label_mapping.items():
            label = str(value)
            # Restricting ``labels`` to this single class makes every averaging
            # mode return this class's own score. Without it, sklearn ignores
            # ``pos_label`` for non-'binary' averages and would report the same
            # aggregate for every label. For average='binary' sklearn replaces
            # ``labels`` with [pos_label], so that mode behaves as before.
            score_kwargs = {'average': average, 'pos_label': label, 'labels': [label]}
            # float() keeps str(dict) output free of numpy scalar reprs.
            p = float(precision_score(y_true, y_pred, **score_kwargs))
            r = float(recall_score(y_true, y_pred, **score_kwargs))
            f1 = float(f1_score(y_true, y_pred, **score_kwargs))
            logger.debug('%s', {'value': value, '召回率为': r, '精确率为': p, 'F1': f1})
            logger.info('标签为%s', name)
            logger.info('精确率为%.2f', p)
            logger.info('召回率为%.2f', r)
            logger.info('F1为%.2f', f1)
            result.append(str({'label': value, 'recall': r, 'precision': p, 'F1': f1}))
        return ' '.join(result)

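# Example usage (requires a config file whose `evaluate` section has an `average` key):
# y_true = [0, 1, 2, 0, 1, 2]
# y_pred = [0, 2, 1, 0, 0, 1]
# evaluator = BaseEvaluator('path/to/config')
# print(evaluator.evaluate(y_true, y_pred, None, logging.getLogger(__name__)))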
