#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @File    : main_server.py
# @Time    : 2023/3/31 10:31
# @Author  : bruxelles_li
# @Software: PyCharm
import logging
import requests
import threading
import sys
import time, os
import json
import pandas as pd
import glob
from pathlib import Path
# Add the parent directory (relative to the current working directory) to the
# module search path; presumably this lets the project packages imported
# below be found.
sys.path.append('../')
# Original intent: "close redundant connections" by turning keep-alive off.
# NOTE(review): the functions in this file call requests.post directly rather
# than through this session object, so this setting does not apply to them.
s = requests.session()
s.keep_alive = False
from classification.runner.runner_fast_text import FastTextRunner_train
from detector_source import sys_info, cpu_info, mem_info
from classification.data.data_process import pro_data
# Log line layout: timestamp, level, process name, thread name, message.
formatter = logging.Formatter("%(asctime)s [%(levelname)s] <%(processName)s> (%(threadName)s) %(message)s")
# Create the module logger and set its level.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Console handler: prints records of level INFO and above.
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
# File handlers: both write into a "log" directory next to this file, which
# is created on import if it does not exist yet.
_tmp_path = os.path.dirname(os.path.abspath(__file__))
_tmp_path = os.path.join(_tmp_path, 'log')
Path(_tmp_path).mkdir(parents=True, exist_ok=True)
# fh receives ERROR and above; fh1 receives INFO and above.
fh = logging.FileHandler(os.path.join(_tmp_path, "main_server_error.log"))
fh1 = logging.FileHandler(os.path.join(_tmp_path, "main_server_info.log"))
fh.setLevel(level=logging.ERROR)
fh1.setLevel(level=logging.INFO)
fh.setFormatter(formatter)
fh1.setFormatter(formatter)
# Send every record to the console and to both log files.
logger.addHandler(ch)
logger.addHandler(fh)
logger.addHandler(fh1)
# Training configuration file, passed to FastTextRunner_train as config_path.
train_config_path = '../classification/config/fasttext_config_train.yml'

# Paths used by data preprocessing (stop_words_path and save_data_path are
# forwarded to pro_data).
root_path = r'../word2vec/doc_similarity/'
stop_words_path = os.path.join(root_path, 'stop_words.txt')
# Template with three "{}" placeholders; presumably filled in by pro_data.
save_data_path = r'../datasets/classification/{}/{}/{}.txt'
# File extensions that merge_df looks for in a dataset directory.
file_types = ['xls', 'xlsx']
# Java callback endpoint that receives status/result updates via POST.
java_call_back_url = "http://192.168.1.82:9988/manage/algorithmModel/process/changeStatus"
# Port of the local queue server reached at http://localhost:<port>.
port = 4005
# Model name forwarded to pro_data.
modelName = "FastText-Model"

# Registry of the training threads started by system_start. (The original
# comment called these "processes"; the objects are threading.Thread.)
all_thread = []


def merge_df(dataset_path):
    """
    Load every Excel workbook in a directory into one de-duplicated DataFrame.

    :param dataset_path: directory searched (non-recursively) for files whose
        extension is listed in the module-level ``file_types``.
    :return: the concatenation of all workbooks with fully identical rows
        removed (the first occurrence is kept).
    :raises ValueError: if the directory contains no matching workbook.
    """
    # glob returns matches in a filesystem-dependent order; sorting keeps the
    # row order of the merged frame reproducible from run to run.
    all_files = sorted(
        path
        for file_type in file_types
        for path in glob.glob(os.path.join(dataset_path, f'*.{file_type}'))
    )
    if not all_files:
        # pd.concat([]) would fail with the opaque "No objects to concatenate";
        # raise the same exception type with a message that names the cause.
        wanted = ', '.join(file_types)
        raise ValueError(f'No files with extension ({wanted}) found in {dataset_path!r}')

    # Merge all workbooks into a single DataFrame with a fresh 0..n-1 index.
    combined_df = pd.concat([pd.read_excel(f) for f in all_files], ignore_index=True)
    # Drop duplicate rows, keeping the first occurrence.
    combined_df.drop_duplicates(keep='first', inplace=True)
    return combined_df


def train_model4FastText(data_path, model_path, modelProcessId, root_dataset):
    """
    Preprocess a dataset, train and evaluate a FastText model, then report the
    evaluation result to the Java callback endpoint.

    :param data_path: forwarded to ``pro_data`` as ``dataFolderName`` and to
        the runner's ``train`` / ``test`` calls.
    :param model_path: forwarded to the runner's ``train`` / ``test`` calls.
    :param modelProcessId: sent back as ``"id"`` in the callback payload.
    :param root_dataset: directory whose Excel files are merged by ``merge_df``.
    :return: the evaluation result of ``runner_train.test`` serialized as a
        JSON string.

    NOTE(review): exceptions raised by preprocessing or training propagate to
    the caller and no failure callback is sent in that case.
    """
    combined_df = merge_df(dataset_path=root_dataset)
    # Preprocess the merged data.
    pro_data(dataFolderName=data_path, data_df=combined_df, stop_words_path=stop_words_path,
             save_data_path=save_data_path, modelName=modelName)
    logger.info("====数据预处理成功，准备进入训练阶段===")
    # Train, then evaluate.
    runner_train = FastTextRunner_train(config_path=train_config_path, model_train=True)
    runner_train.train(data_path=data_path, model_path=model_path, auto_tune_duration=300)
    dict_result = runner_train.test(data_path=data_path, model_path=model_path)
    str_dict_result = json.dumps(dict_result, ensure_ascii=False)
    logger.info(str_dict_result)

    # Report the result through the Java status-update endpoint.
    # NOTE(review): "result" is sent here as a JSON *string*, whereas env_eval
    # sends a dict under the same key — confirm which form the endpoint expects.
    payload = json.dumps({
        "id": modelProcessId,
        "result": str_dict_result
    })
    headers = {
        'Content-Type': 'application/json'
    }
    # Without a timeout a stalled endpoint would block this thread forever.
    callback_timeout = 60  # seconds
    try:
        r1 = requests.post(url=java_call_back_url, headers=headers,
                           data=payload, timeout=callback_timeout)
        logger.info(json.loads(r1.text))
    except (requests.RequestException, ValueError):
        # Training already succeeded: record the delivery failure in the error
        # log instead of losing the result to an uncaught exception.
        logger.exception("Failed to deliver the training result to %s", java_call_back_url)
    return str_dict_result


def _usage_above(usage, limit):
    """Return True when a usage reading is numerically greater than ``limit``."""
    if isinstance(usage, str):
        # Tolerate surrounding whitespace and a trailing percent sign.
        usage = usage.strip().rstrip('%')
    try:
        return float(usage) > limit
    except (TypeError, ValueError):
        logger.warning('Unparseable resource usage reading: %r', usage)
        return False


def env_eval(modelProcessId):
    """
    Check whether the host currently has room to start a training job.

    Reads the readings returned by ``sys_info``, ``cpu_info`` and ``mem_info``.
    When any of them is above its threshold, a failure result is POSTed to the
    Java callback endpoint and ``False`` is returned.

    :param modelProcessId: sent as ``"id"`` in the failure callback payload.
    :return: ``True`` when resources are sufficient, ``False`` otherwise.
    """
    sys_usage = sys_info()
    cpu_usage = cpu_info()
    mem_usage = mem_info()
    # The CPU / memory readings were compared against str(95), i.e. as text.
    # Lexicographic comparison ranks "100.0" below "95", so a saturated host
    # passed the check; compare numerically instead.
    overloaded = (sys_usage > 10000
                  or _usage_above(cpu_usage, 95)
                  or _usage_above(mem_usage, 95))
    if not overloaded:
        return True

    # NOTE(review): the message below quotes thresholds (10 / 85%) that differ
    # from the ones enforced above (10000 / 95); the text is left unchanged
    # because the Java side may display or parse it.
    str_dict_result = {
        'handleMsg': 'failure',
        'isHandleSuccess': False,
        'logs': '模型训练失败！当前模型训练资源占用率过高，请检查系统占用信息【超过10个为高】、CPU占用率【超过85%为高】、物理内存占用率【超过85%为高】',
        'resultData': None
    }
    logger.info(str_dict_result)
    payload = json.dumps({
        "id": modelProcessId,
        "result": str_dict_result
    })
    headers = {
        'Content-Type': 'application/json'
    }
    # Without a timeout a stalled endpoint would block the caller forever.
    callback_timeout = 60  # seconds
    try:
        r1 = requests.post(url=java_call_back_url, headers=headers,
                           data=payload, timeout=callback_timeout)
        logger.info(json.loads(r1.text))
    except (requests.RequestException, ValueError):
        # The verdict does not depend on the notification being delivered.
        logger.exception('Failed to report the resource shortage to %s', java_call_back_url)
    return False


def system_start():
    """
    Poll the local task queue forever and start one training thread per task.

    Each round asks ``/queue_size`` how many tasks are waiting. When the queue
    is empty the loop sleeps; otherwise it pulls that many tasks from
    ``/subject_consumer`` and, for each one, waits until ``env_eval`` reports
    sufficient resources before starting a daemon thread that runs
    ``train_model4FastText``.

    Transport and JSON-decoding failures on the queue endpoints are logged and
    retried instead of terminating the service. A task that lacks an expected
    key still raises ``KeyError``.

    :return: never returns.
    """
    headers = {
        'Content-Type': 'application/json'
    }
    base_url = f'http://localhost:{int(port)}'
    request_timeout = 30   # seconds allowed for one queue request
    idle_sleep = 30        # seconds between polls of an empty/unreachable queue
    busy_sleep = 600       # seconds between resource checks while overloaded

    while True:
        # Forget finished threads so the registry does not grow without bound.
        all_thread[:] = [t for t in all_thread if t.is_alive()]

        try:
            r1 = requests.post(url=f'{base_url}/queue_size', headers=headers,
                               timeout=request_timeout)
            r1_json = json.loads(r1.text)
        except (requests.RequestException, ValueError):
            logger.exception('Failed to query the task queue size; retrying in %d seconds',
                             idle_sleep)
            time.sleep(idle_sleep)
            continue

        queue_left_number = r1_json['queue_left_number']
        logger.info("当前队列任务总数：" + str(queue_left_number))
        if queue_left_number == 0:
            time.sleep(idle_sleep)
            continue

        for _ in range(queue_left_number):
            try:
                r2 = requests.post(url=f'{base_url}/subject_consumer', headers=headers,
                                   timeout=request_timeout)
                r2_json = json.loads(r2.text)
            except (requests.RequestException, ValueError):
                logger.exception('Failed to fetch the next task; polling again in %d seconds',
                                 idle_sleep)
                # Back off before returning to the outer poll.
                time.sleep(idle_sleep)
                break

            config_info = r2_json['data']
            logger.info(config_info)
            modelProcessId = config_info["modelProcessId"]
            model_path = config_info["model_path"]
            data_path = config_info["data_path"]
            root_dataset = config_info["root_dataset"]
            logger.info('##########FastText-Model###############')

            # Hold the task until the host has spare resources. Every failed
            # check makes env_eval send a failure callback for this task.
            while not env_eval(modelProcessId):
                time.sleep(busy_sleep)

            t = threading.Thread(target=train_model4FastText,
                                 args=(data_path, model_path, modelProcessId, root_dataset),
                                 daemon=True)
            t.start()
            all_thread.append(t)


def system_resume():
    """
    Restore the training service to a clean state by emptying the task queue.

    Asks the local queue server for its size; when tasks are pending, keeps
    calling the consumer endpoint until it reports zero tasks left, so that a
    training job is not started twice.

    :return: None
    """
    json_headers = {'Content-Type': 'application/json'}
    server = f'http://localhost:{int(port)}'

    size_reply = requests.post(url=f'{server}/queue_size', headers=json_headers).json()
    logger.info('当前队列数量：%d' % size_reply['queue_left_number'])

    # Nothing queued: report and leave.
    if not size_reply['queue_left_number'] > 0:
        logger.info('队列为空！可放心进行模型训练 ...')
        return

    logger.info('正在消费队列，直到队列为空！')
    remaining = None
    # Consume one task per request until the server reports an empty queue.
    while remaining != 0:
        consume_reply = requests.post(url=f'{server}/subject_consumer', headers=json_headers).json()
        remaining = consume_reply['queue_left_number']
    logger.info('队列消费完毕！可放心进行模型训练 ...')


def start_up_check():
    """
    Verify that the local queue server is reachable before the service starts.

    Sends a single POST to ``/queue_size`` on localhost. Any HTTP response
    counts as "server is up"; the status code is not inspected.

    :return: None
    :raises SystemExit: with exit status 123 when the connection cannot be
        established.
    """
    headers = {
        'Content-Type': 'application/json'
    }
    try:
        requests.post(url=f'http://localhost:{int(port)}/queue_size', headers=headers)
    except requests.exceptions.ConnectionError:
        logger.error("Error: ConnectionError")
        logger.warning('服务未启动，请先启动server! 程序已退出。')
        # sys.exit rather than the bare exit(): the latter is injected by the
        # site module for interactive use and is not guaranteed to exist
        # (e.g. under "python -S" or in frozen builds).
        sys.exit(123)
    logger.info("server启动成功！模型训练服务已启动...")


if __name__ == '__main__':
    # Disabled ad-hoc check of merge_df, kept for manual debugging:
    # root_path = "../datasets/classification/zcjd_column_classify/zcjd_V0"
    # data_df = merge_df(root_path)
    # print(len(data_df))
    # print(data_df)
    # Start the model-training service: first make sure the queue server is up.
    start_up_check()
    logger.info('模型训练服务恢复中 ...')
    # Empty the queue so that no training job is started twice.
    system_resume()
    # NOTE(review): fixed 30-second pause before polling begins; the reason is
    # not evident from this file.
    time.sleep(30)
    logger.info('模型训练服务恢复完成！')
    logger.info('模型训练服务运行中 ...')
    # Enter the endless polling loop.
    system_start()

