#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @File    : app_run.py
# @Time    : 2023/3/31 10:31
# @Author  : bruxelles_li
# @Software: PyCharm
import json
import os
import sys
import logging
import requests
import argparse
import queue
from pathlib import Path
from flask import Flask, jsonify, request
import re

# Model-training service.
# NOTE: relative sys.path entries (and the relative config/upload paths below)
# are resolved against the current working directory, so this script
# presumably has to be launched from its own directory — confirm.
sys.path.append('../')
# Intended to close surplus connections.
# NOTE(review): requests.Session does not appear to read a `keep_alive`
# attribute, so this assignment likely has no effect — confirm. `s` is also
# rebound to a brand-new session further down.
s = requests.session()
s.keep_alive = False
from classification.config.config_fast_text import FastTextConfig
from classification.runner.runner_fast_text import FastTextRunner

# NOTE(review): these CUDA variables are set after the classification imports
# above; if those modules initialise CUDA at import time, the device selection
# may not take effect — confirm.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
# The format string is split across two implicitly concatenated literals.
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] <%(processName)s> (%(threadName)s) %('
                                               'message)s')
logger = logging.getLogger(__name__)

# Basic FIFO (first in, first out) queue of pending training jobs.
# maxsize <= 0 means the queue is unbounded, so put() never blocks.
# The queue is held in this process's memory only; its contents are lost
# when the process exits.
q = queue.Queue(maxsize=0)
# Training config file.
train_config_path = '../classification/config/fasttext_config_train.yml'
# Prediction (application) config file.
pred_config_path = '../classification/config/fasttext_config_pred.yml'

# Intended to close surplus connections (repeats the session set-up above and
# rebinds `s` to a new session).
s = requests.session()
s.keep_alive = False
UPLOAD_FOLDER = r'../datasets/Receive_File'  # upload directory
Path(UPLOAD_FOLDER).mkdir(parents=True, exist_ok=True)

TEMPFILE_FOLDER = UPLOAD_FOLDER + "/" + "Temp_file"
Path(TEMPFILE_FOLDER).mkdir(parents=True, exist_ok=True)

ALLOWED_EXTENSIONS = set(['xls', 'xlsx'])  # file extensions allowed for upload
app = Flask(__name__)
# Expose the upload and temporary-file directories through the app config.
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
app.config['TEMPFILE_FOLDER'] = TEMPFILE_FOLDER

# NOTE(review): star import — which names it brings in (and whether any of
# them shadow names defined above, such as `app`) is not visible here; confirm.
from base.app.base_app import *
# Blueprint providing the supporting services for model training.
from File_Operation.Operation import operation_file
# This needs to change.
# from classification.app.app_ssyw_column_classify import classification_ssyw_column_classify

# Mount the operation_file blueprint (model-training support services).
operation_prefix = "/platform/operation/process"  # upload, delete, test, publish
app.register_blueprint(operation_file, url_prefix="{}".format(operation_prefix))

# Classification training endpoint parameters.
# NOTE(review): train_url, model_name, application_url and model_test_url are
# not referenced by the route decorators in this file, which hard-code their
# paths; they may be read elsewhere — confirm before relying on or removing them.
train_url = "/platform/classification/FastText-Model/model_train/"
model_name = "FastText-Model"

# Prediction (application) URL suffix.
application_url = "/model_pred/"

# Test URL suffix.
model_test_url = "/model_test/"

# Cross-origin (CORS) support.
from flask_cors import CORS

CORS(app, supports_credentials=True)


@app.route('/', methods=['POST'])
def hello_world():
    """Fallback POST handler for the root path.

    Logs a hint telling the caller to use a proper upload endpoint and
    returns the same hint as the response body.
    """
    hint = '请选择正确的方式上传!'
    app.logger.info(hint)
    return hint


@app.route('/subject_consumer', methods=['GET', 'POST'])
def subject_consumer():
    """Pop the next pending training job from the in-memory queue ``q``.

    :return: JSON response. When a job is available: ``message`` with the
        remaining count, ``queue_left_number`` (the remaining count as a
        string, kept for compatibility with existing consumers) and ``data``
        holding the dequeued job dict. When the queue is empty:
        ``message='队列为空！'`` and ``queue_left_number=0``.
    """
    # Non-blocking get instead of ``empty()`` followed by ``get()``: with a
    # threaded server two concurrent requests could both see a non-empty
    # queue, and the one that loses the race would then block forever
    # inside ``q.get()``.
    try:
        config_info = q.get_nowait()
    except queue.Empty:
        return jsonify(message='队列为空！', queue_left_number=0)
    # Read the size once so the message and the counter always agree.
    remaining = q.qsize()
    return jsonify(message='当前队列数量：' + str(remaining),
                   queue_left_number=str(remaining),
                   data=config_info)


@app.route('/queue_size', methods=['GET', 'POST'])
def queue_size():
    """Report how many training jobs are currently waiting in ``q``."""
    pending = q.qsize()
    return jsonify(queue_left_number=pending)


def _train_response(code, logs):
    """Build the response dict shared by every outcome of ``model_train``."""
    return {
        'code': code,
        'isHandleSuccess': code == 200,
        'logs': logs,
        'result': None,
    }


@app.route("/platform/classification/FastText-Model/model_train/", methods=['POST'])
def model_train():
    """Validate a FastText training request and enqueue it for a worker.

    Request JSON keys read by this handler (all required)::

        {
            "modelProcessId": "1453295295008211969",
            "task_id": "...",
            "data_path": "/corpus/version",
            "model_path": "/model/version"
        }

    ``task_id`` is prefixed to both ``data_path`` and ``model_path``, and the
    results are substituted into the ``%``-format templates ``data.path0``
    and ``learn.dir.saved`` of the training config.

    The request is rejected (code 500) when the resolved corpus path does not
    exist, or when the resolved model path already exists (that model version
    is taken). Otherwise a job dict is put on ``q`` (drained by
    ``/subject_consumer``) and code 200 is returned. Any other error
    (malformed JSON, missing key, config problem) is reported as code 500
    with the exception text.

    :return: JSON string with ``code``, ``isHandleSuccess``, ``logs`` and
        ``result`` keys.
    """
    try:
        data = json.loads(request.data.decode('utf-8'))
        modelProcessId = data['modelProcessId']
        # str() so a task id sent as a JSON number does not fail the string
        # concatenation below with a TypeError.
        task_id = str(data["task_id"])
        data_path_0 = data['data_path']
        model_path_0 = data['model_path']
        # NOTE(review): task_id / data_path / model_path come straight from the
        # request and are joined into filesystem paths without rejecting "..";
        # validate them if this endpoint is reachable by untrusted clients.
        data_path_1 = "/" + task_id + "/" + data_path_0.strip('/')
        model_path = '/' + task_id + "/" + model_path_0.strip('/')
        # Load the config to find where corpora and model files are stored.
        _config = FastTextConfig(config_path=train_config_path).load_config()
        data_temp_path = _config.data.path0 % data_path_1
        app.logger.info(data_temp_path)
        temp_path = _config.learn.dir.saved % model_path
        app.logger.info(temp_path)

        if not os.path.exists(data_temp_path):
            dict_result = _train_response(
                500, '模型训练失败！当前模型训练的语料文件不存在，请上传语料后再进行训练')
        elif os.path.exists(temp_path):
            # NOTE(review): this only checks the filesystem; two requests for
            # the same model version queued before either is trained both pass.
            dict_result = _train_response(
                500, '模型训练失败！当前模型版本已经存在，请更改模型版本号再重新进行训练')
        else:
            # Corpus exists and the model version is free: hand off to a worker.
            config_info = {
                "modelProcessId": modelProcessId,
                "data_path": data_path_1,
                "model_path": model_path,
                'root_dataset': data_temp_path,
            }
            q.put(config_info)
            app.logger.info(config_info)
            dict_result = _train_response(200, '模型训练中 ...')
    except Exception as e:
        # Boundary handler: keep the JSON contract but record the traceback.
        app.logger.exception('model_train failed')
        dict_result = _train_response(500, '训练失败！' + str(e))
    app.logger.info(dict_result)
    return json.dumps(dict_result, ensure_ascii=False)


@app.route("/platform/classification/FastText-Model/model_test/", methods=['POST'])
def model_test():
    """Run a one-off test prediction with a freshly built FastText runner.

    Request JSON keys read here (each optional, ``None`` when absent):
    ``title``, ``content`` and ``model_path``. A new ``FastTextRunner`` is
    built from the prediction config and ``model_path`` on every call, and
    its ``test`` result dict is returned. When that result's ``code`` is not
    200, its ``logs`` text is prefixed with a failure message. Any exception
    produces a code-500 dict with ``handleMsg``, ``code``, ``logs`` and
    ``resultData`` keys.

    :return: JSON string.
    """
    try:
        payload = json.loads(request.data.decode('utf-8'))
        title, content, model_path = (
            payload[key] if key in payload else None
            for key in ('title', 'content', 'model_path')
        )
        runner_test = FastTextRunner(config_path=pred_config_path, model_path=model_path)
        dict_result = runner_test.test(title=title, content=content)
        failed = dict_result['code'] != 200
        if failed:
            dict_result['logs'] = '模型测试失败！' + dict_result['logs']
    except Exception as err:
        dict_result = {
            'handleMsg': 'failure',
            'code': 500,
            'logs': '模型测试失败！' + str(err),
            'resultData': None
        }

    app.logger.info(dict_result)
    return json.dumps(dict_result, ensure_ascii=False)


def _get_pred_runner():
    """Return the shared prediction runner, creating a default one if missing.

    ``ssyw_runner`` is normally bound in the ``__main__`` block. When the
    module is imported by another server instead (so that block never runs),
    the name would be undefined and every prediction would fail with a
    NameError. In that case, fall back to a runner built from the prediction
    config, as the non-micro-service branch of ``__main__`` does.
    """
    runner = globals().get('ssyw_runner')
    if runner is None:
        runner = FastTextRunner(config_path=pred_config_path)
        globals()['ssyw_runner'] = runner
    return runner


@app.route("/platform/classification/FastText-Model/model_pred/", methods=['POST'])
def model_pred():
    """Classify a batch of documents with the shared FastText runner.

    Expects a JSON list of objects, each with optional ``title``, ``content``
    and ``id`` keys. Each item's prediction is stripped of surrounding
    whitespace and returned as ``{"id": ..., "labels": ...}`` in input order.

    :return: JSON string. On success: ``code`` 200 with the ``result`` list.
        If any item fails, the whole batch fails: ``code`` 500 with the
        exception text appended to ``message``.
    """
    try:
        data_list = json.loads(request.data.decode('utf-8'))
        # Resolve the runner once per request rather than once per item.
        runner = _get_pred_runner()
        result_one = []
        for data in data_list:
            title = data['title'] if 'title' in data else None
            content = data['content'] if 'content' in data else None
            infoId = data["id"] if "id" in data else None
            level2 = runner.pred(
                title=title,
                content=content
            ).strip()
            result_one.append({
                "id": infoId,
                'labels': level2
            })
        dict_result = {
            'code': 200,
            'message': "操作成功",
            'result': result_one
        }
    except Exception as e:
        app.logger.exception('model_pred failed')
        dict_result = {
            'code': 500,
            'success': 'false',
            'message': "操作失败" + str(e),
            'result': None
        }

    app.logger.info(dict_result)

    return json.dumps(dict_result, ensure_ascii=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='FastText classification model service')
    # type=int makes argparse reject a malformed port up front, instead of
    # failing in int() only after the prediction model has been loaded.
    parser.add_argument('-port', dest='port', type=int, default=4005,
                        help='port to listen on outside micro-service mode (default: 4005)')
    parser.add_argument('-host', dest='host', default='0.0.0.0',
                        help='interface to bind (default: 0.0.0.0)')

    # Micro-service parameters: only when BOTH are given does the prediction
    # runner load <model_path>/model.bin and the app listen on micro_server_port.
    parser.add_argument('-model_path', dest='model_path', default='',
                        help='directory containing model.bin (micro-service mode)')
    parser.add_argument('-micro_server_port', dest='micro_server_port', type=int, default=None,
                        help='port to listen on in micro-service mode')

    args = parser.parse_args()
    if args.model_path and args.micro_server_port:
        model_path = os.path.join(args.model_path, "model.bin")
        ssyw_runner = FastTextRunner(config_path=pred_config_path, model_path=model_path)
        port = args.micro_server_port
    else:
        ssyw_runner = FastTextRunner(config_path=pred_config_path)
        port = args.port
    app.run(host=args.host, port=int(port))




