#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @File    : 矩阵测试.py
# @Time    : 2022/12/12 19:23
# @Author  : bruxelles_li
# @Software: PyCharm
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd
from bert_serving.client import BertClient
from tqdm import tqdm
import numpy as np
from numpy import *
import datetime
import logging
from es_byid import find_sent_info, find_para_info, find_art_info, find_sen_content
# '114.115.130.239',
# Module-level BERT-as-service client shared by encode_sentences and get_result.
# NOTE(review): the IP in the comment above looks like a leftover BertClient(ip=...)
# argument; as written, the client is built without an ip, so it presumably uses
# the library's default host. Confirm which server is intended.
bc = BertClient(check_length=False)
# Similarity threshold constant. NOTE(review): not referenced elsewhere in this file.
prob = 0.85
# file_path = "素材库/句子库/句子库.npy"
# vector_path = "测试文件/sent.txt"
# Path of the saved id+vector matrix that the __main__ block loads with np.load.
np_path = "database/sent_database/other_sent.npy"
# Configure the root logger: INFO level, with timestamp, level, process and thread names.
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] <%(processName)s> (%(threadName)s) %('
                                               'message)s')
logger = logging.getLogger(__name__)
# b = np.load(file_path)
# df = pd.read_excel("测试文件/句子库测试样例.xlsx", keep_default_na=False).astype(str)
# length = len(df)
# b = np.array([[1,2,3,4],[5,6,7,8],[11,12,13,14],[2,3,4,5]])
# print(b.shape[0], b.shape[1])
# print(b.shape)
# print(b)
# c = b.transpose()
# print(c)
# print(c[1::])
# d = c[1::].transpose()
# print(d)
#
# print(c[0], type(c[0]))
# d = c[0].tolist()
# print(d, type(d))


def encode_sentences(vector_path, df, length, np_path):
    """Encode every sentence of ``df`` with the BERT client and build the vector matrix.

    Writes one line per row of ``df`` to ``vector_path`` in the form
    ``"<id> <v1> <v2> ..."`` (space separated, UTF-8). It then calls
    ``save_file`` to turn that text file into the ``.npy`` matrix at ``np_path``.

    :param vector_path: intermediate text file to (over)write.
    :param df: DataFrame with at least ``id`` and ``content`` columns.
    :param length: row count forwarded to ``save_file`` (normally ``len(df)``).
    :param np_path: destination of the saved matrix.
    :return: None
    """
    with open(vector_path, 'w', encoding='utf-8') as f_vectors:
        # iterrows() is a generator, so pass total= to give tqdm a real progress bar.
        for _, row in tqdm(df.iterrows(), total=len(df)):
            # encode() takes a batch; take the single vector for this one sentence.
            vector = bc.encode([row['content']])[0]
            f_vectors.write(str(row['id']) + ' ' + ' '.join(map(str, vector)) + '\n')
    # The `with` block has already closed the file, so no explicit close() is needed.
    save_file(length, vector_path, np_path)

    return None


def save_file(length, vector_path, np_path, width=769):
    """Parse the id+vector text file at ``vector_path`` into a matrix and save it to ``np_path``.

    Each line is a space-separated record, as written by ``encode_sentences``:
    the record id followed by the vector components. Line ``i`` becomes row
    ``i`` of a ``(length, width)`` float matrix, so the id lands in column 0.
    Rows beyond the number of lines in the file stay zero. A file with more
    than ``length`` lines raises IndexError, and a line without exactly
    ``width`` fields raises ValueError.

    :param length: number of rows to allocate (normally the number of lines).
    :param vector_path: text file to read (UTF-8, matching the writer).
    :param np_path: destination passed to ``np.save``.
    :param width: columns per row; the default 769 is 1 id column + 768 components.
    :return: None
    """
    matrix = np.zeros((int(length), width), dtype=float)
    # Read with the same encoding the file was written with, and close it reliably.
    with open(vector_path, encoding='utf-8') as f:
        for row_idx, line in enumerate(f):
            # numpy converts the string fields to float on assignment.
            matrix[row_idx, :] = line.strip('\n').split(' ')
    print(matrix.shape)
    np.save(np_path, matrix)


def _match_rows(query_vec, rows, threshold):
    """Score ``rows`` against ``query_vec`` and keep the rows at or above ``threshold``.

    :param query_vec: encoded query as returned by ``bc.encode([text])``
        (assumed to hold a single query row).
    :param rows: 2-D slice of the stored matrix; column 0 is the record id and
        the remaining columns are the sentence vector.
    :param threshold: minimum cosine similarity for a row to be kept.
    :return: ``(similarities, ids)`` as two parallel lists in row order; ids are strings.
    """
    if rows.shape[0] == 0:
        # np.array_split yields empty chunks when the matrix has fewer rows than
        # chunks, and cosine_similarity raises ValueError on zero samples.
        return [], []
    sims = cosine_similarity(query_vec, rows[:, 1:])[0]
    hits = np.flatnonzero(sims >= threshold)
    # Ids are whole numbers stored as floats. int() avoids the str(x).split(".")
    # approach, which turns ids >= 1e16 (repr "1.67...e+16") into "1".
    ids = [str(int(record_id)) for record_id in rows[hits, 0]]
    return sims[hits].tolist(), ids


def get_result(text, np_arrary, threshold=0.85, n_chunks=800, max_hits=30):
    """Benchmark chunked vs. full-matrix cosine search of ``text`` against ``np_arrary``.

    The query is encoded once with the BERT client. The matrix is then searched
    twice, and the match count and wall-clock time of each pass are logged:

    1. Split into ``n_chunks`` row chunks, stopping once at least ``max_hits``
       matches have been collected. This strategy is intended for very large
       matrices (around 4 GB), where a single full pass takes several seconds.
    2. A single pass over the full matrix.

    :param text: query sentence.
    :param np_arrary: matrix whose column 0 holds record ids and whose remaining
        columns hold the sentence vectors.
    :param threshold: minimum cosine similarity for a match.
    :param n_chunks: number of row chunks for the split search (must be >= 1).
    :param max_hits: stop the split search once this many matches are collected.
    :return: None; only match counts and timings are logged.
    """
    query_vec = bc.encode([text])

    # Chunked search: small cosine_similarity calls, with an early stop.
    start0_time = datetime.datetime.now()
    sim_result = []
    id_result = []
    for chunk in np.array_split(np_arrary, n_chunks):
        if len(sim_result) >= max_hits:
            break
        sims, ids = _match_rows(query_vec, chunk, threshold)
        sim_result.extend(sims)
        id_result.extend(ids)
    total0_time = (datetime.datetime.now() - start0_time).total_seconds()
    logger.info(len(id_result))
    logger.info("拆分矩阵计算 共消耗: " + "{:.2f}".format(total0_time) + " 秒")

    # Full-matrix search. Reuse query_vec rather than re-encoding, so that this
    # timing measures only the similarity search, just like the chunked timing.
    start1_time = datetime.datetime.now()
    _, id_list = _match_rows(query_vec, np_arrary, threshold)
    logger.info(len(id_list))
    total1_time = (datetime.datetime.now() - start1_time).total_seconds()
    logger.info("全矩阵计算 共消耗: " + "{:.2f}".format(total1_time) + " 秒")

    return None


if __name__ == "__main__":
    # Time how long loading the saved id+vector matrix from disk takes.
    start0_time = datetime.datetime.now()
    np_arrary = np.load(np_path)
    end0_time = datetime.datetime.now()
    total0_time = (end0_time - start0_time).total_seconds()
    logger.info("加载矩阵 共消耗: " + "{:.2f}".format(total0_time) + " 秒")
    # To benchmark chunked vs. full-matrix similarity search on the loaded matrix,
    # call for example:
    # get_result("张文魁:应进一步分行业设立国企负债率警戒线和监管线", np_arrary)
    # To rebuild the matrix from a DataFrame with `id` and `content` columns:
    # encode_sentences(vector_path, df, len(df), np_path)
