#!/usr/bin/env python 
# -*- coding: utf-8 -*-
# @File    : test_br_pro_risk_recognition.py
# @Time    : 2022/1/5 18:09
# @Author  : Mr.Ygg
# @Software: PyCharm

from base.app.base_app import *
from classification.runner.runner_fast_text import FastTextRunner
from classification.utils.utils import load_risk_keywords, is_include_compound_words

# Risk categories (full label set).
# NOTE: every entry needs its trailing comma. Python silently concatenates
# adjacent string literals, so a missing comma merges two categories into one.
risk_info = [
    '外部政治风险',
    '主权政治风险',
    '社会动荡风险',
    '对华关系风险',
    '资金风险',
    '财政风险',
    '汇率风险',
    '通货膨胀风险',
    '环保风险',
    '法律风险',
    '突发事件风险',
    '项目实施风险',
    '企业风险',
    '其他风险',
]
# Main risk-type recognition model.
ft_config_path = '../config/config_br_pro_risk_recognition.yml'
runner = FastTextRunner(config_path=ft_config_path)

# Recruitment / stock information screening model.
ft_config_path_rc_f_zp_gp = '../config/config_rc_f_zp_gp.yml'
runner_rc_f_zp_gp = FastTextRunner(config_path=ft_config_path_rc_f_zp_gp)
# Project-news positive/negative sentiment model.
ft_config_path_psa = '../config/config_br_pro_sentiment_analysis.yml'
runner_psa = FastTextRunner(config_path=ft_config_path_psa)

# Country names searched for by the country filter in `pred`. Each line of
# country.txt is cut at the first ASCII '(' or full-width '（', dropping any
# parenthesised annotation after the name.
list_country = []
with open('../config/country.txt', 'r', encoding='utf-8') as f:
    lines = f.readlines()
for line in lines:
    country_name = line.strip().split('(')[0].split('（')[0].strip()
    # Skip empty names: '' is a substring of every string, so a blank line in
    # country.txt would make the country filter accept every text.
    if country_name:
        list_country.append(country_name)

# Risk types the model can recognise directly; `pred` double-checks these
# predictions against their keyword rules.
risk_model_info = [
    '社会动荡风险',
    '突发事件风险'
]
# Keyword rules per risk type. `pred` iterates this as a mapping whose values
# are iterables of rule strings, where '+' joins the words of a compound rule
# (presumably one rule per spreadsheet cell; confirm against load_risk_keywords).
dict_risk_keywords = load_risk_keywords('../config/risk_keywords.xlsx')


_NO_RISK = '无风险'
_RISK_KEYS = ('风险类别1', '风险类别2', '风险类别3', '风险类别4')


def _as_text(value) -> str:
    """Coerce a cell value to str; None and NaN (empty pandas cells) become ''."""
    if isinstance(value, str):
        return value
    # NaN is the only value that compares unequal to itself.
    if value is None or value != value:
        return ''
    return str(value)


def _find_country(title: str, content: str):
    """Return the first entry of `list_country` found in the title plus the
    first fifth of the content, or None when no country matches."""
    text = title + '。' + content[: len(content) // 5]
    return next((country for country in list_country if country in text), None)


def _match_keywords(text: str, risk_keywords) -> bool:
    """Return True if *text* satisfies at least one keyword rule.

    Each rule is split on '+' and passed to `is_include_compound_words` as a
    list of compound words. Stops at the first matching rule.
    """
    return any(
        is_include_compound_words(text=text, compound_words=risk_keyword.split('+'))
        for risk_keyword in risk_keywords
    )


def _recall_risk_by_keywords(text: str, exclude=()) -> str:
    """Return the risk type whose keyword rules match *text* most often.

    Risk types listed in *exclude* are not considered. Ties go to the type that
    comes first in `dict_risk_keywords` (built-in `max` keeps the first maximum).
    Returns '无风险' when no rule of any considered type matches.
    """
    hit_counts = {}
    for risk_type in dict_risk_keywords:
        if risk_type in exclude:
            continue
        hit_counts[risk_type] = sum(
            1
            for risk_keyword in dict_risk_keywords[risk_type]
            if is_include_compound_words(text=text, compound_words=risk_keyword.split('+'))
        )
    if not any(hit_counts.values()):
        return _NO_RISK
    return max(hit_counts, key=hit_counts.get)


def pred(title: str, content: str) -> dict:
    """Classify the risk type of one news item through a chain of filters.

    Args:
        title: news title.
        content: news body. Non-string values (None, or NaN from an empty
            spreadsheet cell) are treated as '' for both title and content.

    Returns:
        dict with keys '风险类别1' .. '风险类别4':
          - 风险类别1: raw output of the risk model.
          - 风险类别2: as 风险类别1, except that a '无风险' prediction is
            replaced by the keyword-recalled category.
          - 风险类别3: after keyword validation and keyword recall.
          - 风险类别4: after keyword validation only. Rejected model
            predictions stay '无风险' (no recall).
        Non-string model output becomes 'error' in 风险类别3/4. If no country
        is found or the sentiment model does not report negative news, every
        key is '无风险'.
    """
    title = _as_text(title)
    content = _as_text(content)

    # Recruitment/stock screening model: 0 = neither, 1 = recruitment, 2 = stock.
    result_rc_f_zp_gp = runner_rc_f_zp_gp.pred(title=title, content=content)
    bool_rc_f_zp_gp = result_rc_f_zp_gp != '1'
    logger.info('招聘股票筛选模型: {}'.format(result_rc_f_zp_gp))
    logger.info('招聘股票筛选模型: {}'.format(bool_rc_f_zp_gp))
    # NOTE(review): bool_rc_f_zp_gp is only logged. It is NOT part of the gate
    # below, even though the gate comment lists it as step 1. It also only
    # rejects label '1' (recruitment), not '2' (stock). Confirm the intent.

    # Positive/negative sentiment model: only negative project news passes.
    result_psa = runner_psa.pred(title=title, content=content)
    bool_psa = result_psa == '项目负面资讯信息'
    logger.info('正负面筛选模型: {}'.format(result_psa))
    logger.info('正负面筛选模型: {}'.format(bool_psa))

    # Country filter over the title and the first fifth of the body.
    matched_country = _find_country(title, content)
    bool_country = matched_country is not None
    if bool_country:
        logger.info('国家识别筛选模型: {}'.format(matched_country))
    logger.info('国家识别筛选模型: {}'.format(bool_country))

    if not (bool_country and bool_psa):
        logger.info('招聘股票|国家识别筛选: 无风险')
        return dict.fromkeys(_RISK_KEYS, _NO_RISK)

    # Gate passed: (1) recruitment/stock screening [see NOTE above],
    # (2) country filter, (3) negative-sentiment filter.
    text = title + '。' + content
    result = runner.pred(title=title, content=content)
    dict_result = dict.fromkeys(_RISK_KEYS, result)
    logger.info('风险识别筛选模型: {}'.format(result))

    if isinstance(result, str) and result in risk_model_info:
        # Types the model recognises directly must also satisfy one of their
        # own keyword rules; otherwise the prediction is treated as noise.
        bool_risk_keyword = _match_keywords(text, dict_risk_keywords[result])
        result = result if bool_risk_keyword else _NO_RISK
        dict_result['风险类别3'] = result
        dict_result['风险类别4'] = result
        logger.info('关键词筛选: {}'.format(bool_risk_keyword))
        if result == _NO_RISK:
            # Rejected prediction: try to recall a category from any keyword set.
            risk_category = _recall_risk_by_keywords(text)
            dict_result['风险类别3'] = risk_category
            logger.info('关键词筛选后召回风险信息: {}'.format(risk_category))
    elif isinstance(result, str) and result == _NO_RISK:
        # The model saw no risk: recall useful categories by keywords, but never
        # the types the model itself recognises (it already rejected those).
        risk_category = _recall_risk_by_keywords(text, exclude=risk_model_info)
        dict_result['风险类别2'] = risk_category
        dict_result['风险类别3'] = risk_category
        dict_result['风险类别4'] = risk_category
        logger.info('关键词召回风险信息: {}'.format(risk_category))
    else:
        result = result if isinstance(result, str) else 'error'
        dict_result['风险类别3'] = result
        dict_result['风险类别4'] = result
        logger.info('ELSE 风险信息: {}'.format(result))

    return dict_result


if __name__ == '__main__':
    # Batch evaluation: run `pred` over a spreadsheet of news items and write the
    # four risk columns, plus the raw risk-model output ('风险类别_old'), to a
    # new spreadsheet under output_file/.
    import os
    import pandas
    root_dir = '../data/datasource/test'
    # Alternative input file: 'br总资讯'
    file_name = '境外快讯_1.4'
    df = pandas.read_excel(os.path.join(root_dir, 'input_file/{}.xlsx'.format(file_name)))
    df.drop_duplicates(subset='标题', keep='first', inplace=True)
    # Empty cells are read as NaN (float), and `pred` concatenates title and
    # content as strings, so normalise them to '' first.
    list_title = df['标题'].fillna('').astype(str)
    list_content = df['正文'].fillna('').astype(str)
    dict_risk_result = {
        key: [] for key in ('风险类别1', '风险类别2', '风险类别3', '风险类别4')
    }
    list_risk_old = []
    total = len(list_title)
    for index, (title, content) in enumerate(zip(list_title, list_content)):
        dict_result = pred(title=title, content=content)
        for key, values in dict_risk_result.items():
            values.append(dict_result.get(key, 'error'))

        # Raw output of the risk model without any filtering, kept for comparison.
        result_old = runner.pred(title=title, content=content)
        list_risk_old.append(result_old)
        logger.info('{} / {}\n'.format(index + 1, total))

    df['风险类别_old'] = list_risk_old
    for key, values in dict_risk_result.items():
        df[key] = values

    df.to_excel(os.path.join(root_dir, 'output_file/{}_result_20220112_s.xlsx'.format(file_name)))