#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @File     : runner_fast_text
# @Author   : LiuYan
# @Time     : 2021/4/15 16:44

import json
import os
import sys
import time
import warnings
from pathlib import Path
from typing import Optional, Union

import fasttext
import pandas as pd

sys.path.append('../../')
from utils.utils import timeit
from base.runner.base_runner import BaseRunner, train_BaseRunner
from classification.config.config_fast_text import FastTextConfig
from classification.evaluation.classify_evaluator import ClassifyEvaluator
from classification.utils.utils import *
warnings.filterwarnings('ignore')
fasttext.FastText.eprint = lambda x: None


class FastTextRunner_train(train_BaseRunner):
    """fastText text-classification runner driven by per-request data/model names.

    The configured ``data.train_path`` / ``data.valid_path`` / ``data.test_path`` and
    ``learn.dir.saved`` values are used as ``%``-format templates. They are filled with the
    ``data_path`` / ``model_path`` arguments passed to ``train`` and ``test``. Each run writes
    its outputs into ``<learn.dir.saved % model_path>-<timestamp>/``.
    """

    def __init__(self, config_path: str, model_train=False, model_path=None):
        """Load the config and build paths, model and evaluator.

        :param config_path: path of the config file loaded by ``FastTextConfig``.
        :param model_train: True when driven by a training request. Dataset paths are always
            resolved, ``test`` runs validation and returns a response dict, and no model is
            loaded at construction time.
        :param model_path: optional path of a saved model file. It overrides
            ``config.learn.dir.load_model`` and enables the single-document branch of ``test``.
        """
        super(FastTextRunner_train, self).__init__()
        self._config_path = config_path
        self._config = None

        # Timestamp naming this run's output directory (<saved % model_path>-<time>/).
        self._time = time.strftime('%Y_%m_%d-%H_%M_%S')
        self._model_train = model_train
        self._model_path = model_path

        # Filled by _build_data depending on config.status. Declared here so they always
        # exist (pred loads stop words lazily when they were not preloaded).
        self._train_path = None
        self._valid_path = None
        self._test_path = None
        self._stop_words = None

        self._train_dataloader = None
        self._valid_dataloader = None
        self._test_dataloader = None

        self._model = None
        self._loss = None
        self._optimizer = None

        self._evaluator = None
        self._build()

    @timeit
    def _build(self):
        """Run every build step in dependency order (config first)."""
        self._build_config()
        self._build_data()
        self._build_model()
        self._build_loss()
        self._build_optimizer()
        self._build_evaluator()

    @timeit
    def _build_config(self):
        """Load the config object from ``self._config_path``."""
        self._config = FastTextConfig(config_path=self._config_path).load_config()

    @timeit
    def _build_data(self):
        """Resolve dataset path templates for train/test, otherwise preload stop words for prediction."""
        if self._config.status in ['train', 'test'] or self._model_train:
            self._train_path = self._config.data.train_path
            self._valid_path = self._config.data.valid_path
            self._test_path = self._config.data.test_path
        else:
            self._stop_words = self._load_stop_words()

    def _load_stop_words(self):
        """Load the stop-word list that ``pred`` passes to ``seg``."""
        # NOTE(review): resolved relative to the current working directory, not the config.
        return stop_words(path=r'../../word2vec/f_zp_gp/stop_words.txt')

    @timeit
    def _build_model(self):
        """Apply the ``model_path`` override and load the model for test/pred (non-training) runs."""
        if self._model_path:
            self._config.learn.dir.load_model = self._model_path
        if self._config.status in ['test', 'pred'] and not self._model_train:
            self._load_model()

    @timeit
    def _build_loss(self):
        """Not used: training is delegated to ``fasttext.train_supervised``."""
        pass

    @timeit
    def _build_optimizer(self):
        """Not used: training is delegated to ``fasttext.train_supervised``."""
        pass

    @timeit
    def _build_evaluator(self):
        """Create the evaluator used by ``_valid``."""
        self._evaluator = ClassifyEvaluator()

    @timeit
    def train(self, data_path, model_path, auto_tune_duration=500, auto_tune_model_size='200M'):
        """Train a supervised fastText model with autotuning, then save it.

        :param data_path: value substituted into the configured train/test path templates.
        :param model_path: value substituted into ``learn.dir.saved`` to name the output dir.
        :param auto_tune_duration: passed to fastText as ``autotuneDuration``.
        :param auto_tune_model_size: passed to fastText as ``autotuneModelSize`` (e.g. '200M').
        """
        self._model = fasttext.train_supervised(
            input=self._train_path % data_path,
            # The test split drives hyper-parameter search; _valid reports on the valid split.
            autotuneValidationFile=self._test_path % data_path,
            autotuneDuration=auto_tune_duration,
            autotuneModelSize=auto_tune_model_size
        )
        self._save_model(model_path)

    def _train_epoch(self, epoch: int):
        """Not used: fastText trains all epochs inside ``train_supervised``."""
        pass

    @staticmethod
    def _parse_line(line: str):
        """Split a fastText-format line ``__label__<label> <text>`` into ``(label, text)``.

        The text excludes the label token and the trailing line break. This works for labels of
        any length and for a last line without a newline.
        """
        stripped = line.replace('__label__', '').rstrip('\r\n')
        label, _, text = stripped.partition(' ')
        return label, text

    def _predict_label(self, text: str) -> str:
        """Return the model's top label for ``text`` without the '__label__' prefix."""
        return self._model.predict(text)[0][0].replace('__label__', '')

    def _run_dir(self, model_path) -> str:
        """Return this run's output directory: ``learn.dir.saved % model_path`` + '-<time>'."""
        return self._config.learn.dir.saved % model_path + '-{}'.format(self._time)

    def _valid(self, data_path, model_path, epoch: int) -> Optional[dict]:
        """Evaluate the model on the validation split.

        When ``config.status == 'train'`` or ``model_train`` is set, the evaluator's result dict
        is written as JSON to ``<run dir>/evaluation_metrics.json``.

        :param data_path: value substituted into the valid-path template.
        :param model_path: value substituted into ``learn.dir.saved`` for the output dir.
        :param epoch: unused; kept for interface compatibility.
        :return: a training response dict when ``model_train`` is set, otherwise None.
        """
        with open(self._valid_path % data_path, encoding='utf-8') as file:
            self._valid_dataloader = file.readlines()
        labels = []
        pre_labels = []
        for line in self._valid_dataloader:
            if not line.strip():
                # Blank lines (e.g. a trailing empty line) are not samples.
                continue
            label, text = self._parse_line(line)
            labels.append(label)
            pre_labels.append(self._predict_label(text))

        _, _, _, dict_result = self._evaluator.evaluate(true_list=labels, pred_list=pre_labels)

        if self._config.status == 'train' or self._model_train:
            run_dir = self._run_dir(model_path)
            # The directory may not exist yet if no model was saved in this run.
            Path(run_dir).mkdir(parents=True, exist_ok=True)
            with open(run_dir + '/evaluation_metrics.json', 'w', encoding='utf-8') as f:
                f.write(json.dumps(dict_result))
        if self._model_train:
            average = dict_result['average']
            return {
                'code': 200,
                'result': '模型训练成功！模型评测指标为: precision: {:.0f}%  recall: {:.0f}%  f1-score: {:.0f}%'.format(
                    average['precision'] * 100,
                    average['recall'] * 100,
                    average['f1-score'] * 100
                ),
                'model_path': self._run_dir(model_path) + '/model.bin'
            }
        return None

    def test(self, data_path=None, model_path=None, title=None, content=None) -> Optional[dict]:
        """Validate, or predict one document, depending on how the runner was created.

        - ``model_train``: validate and return the training response dict.
        - ``model_path`` set: read ``evaluation_metrics.json`` from the model's directory,
          predict ``title``/``content`` and return a success response containing the label and
          formatted metrics, or ``pred``'s failure dict.
        - otherwise: validate and return None.
        """
        if self._model_train:
            return self._valid(data_path=data_path, model_path=model_path, epoch=100)
        if not self._model_path:
            self._valid(data_path=data_path, model_path=model_path, epoch=100)
            return None

        metrics_path = os.path.join(os.path.dirname(self._model_path), 'evaluation_metrics.json')
        with open(metrics_path, 'r', encoding='utf-8') as f:
            average = json.load(f)['average']
        evaluation_metrics = {
            '精确率(P)': '{:.0f}%'.format(average['precision'] * 100),
            '召回率(R)': '{:.0f}%'.format(average['recall'] * 100),
            'F1值(F1)': '{:.0f}%'.format(average['f1-score'] * 100)
        }
        result = self.pred(title=title, content=content)
        if not isinstance(result, str):
            # pred returned its failure response; pass it through unchanged.
            return result
        return {
            'handleMsg': 'success',
            'code': 200,
            'logs': '模型测试成功！',
            'result': {
                'label': result,
                'evaluation_metrics': evaluation_metrics
            }
        }

    @staticmethod
    def _failure(value) -> dict:
        """Build the failure response ``pred`` returns for non-string input."""
        return {
            'handleMsg': 'failure',
            'code': 300,
            'logs': '{} is not str!'.format(value),
            'result': {
                'label': None
            }
        }

    def pred(self, title: str, content: str) -> Union[str, dict]:
        """Predict the label of one document.

        The title is repeated twice, each copy followed by '。', before the content. Presumably
        this weights the title more heavily. The text is cleaned, stripped of line breaks and tabs,
        and segmented with the stop-word list.

        :return: the predicted label string, or a failure dict when the title, content or cleaned
            text is not a string (e.g. NaN cells read from Excel).
        """
        for value in (title, content):
            if not isinstance(value, str):
                return self._failure(value)
        text = clean_txt(raw=clean_tag(text=(title + '。') * 2 + content))
        if not isinstance(text, str):
            return self._failure(text)
        text = text.replace('\n', '').replace('\r', '').replace('\t', '')
        if self._stop_words is None:
            # Stop words are only preloaded for status 'pred'; load them on demand otherwise.
            self._stop_words = self._load_stop_words()
        text = seg(text=text, sw=self._stop_words)
        return self._predict_label(text)

    def pred_file(self, file_path: str, result_path: str) -> Optional[dict]:
        """Label every row of an Excel file and write the result to ``result_path``.

        Reads the ``title`` and ``content`` columns and adds a ``label`` column: '是' when the
        predicted label is '1', '否' otherwise.

        :return: None on success. For the first row that cannot be predicted, returns ``pred``'s
            failure dict; nothing is written in that case.
        """
        data_loader = pd.read_excel(file_path)
        labels = []
        for title, content in zip(data_loader['title'], data_loader['content']):
            pred_result = self.pred(title, content)
            if not isinstance(pred_result, str):
                return pred_result
            labels.append('是' if pred_result == '1' else '否')

        data_loader['label'] = labels
        data_loader.to_excel(result_path)
        return None

    def _display_result(self, dict_result: dict):
        """Not implemented for this runner."""
        pass

    def _save_model(self, model_path):
        """Print the target path, then save the model to ``<run dir>/model.bin``."""
        run_dir = self._run_dir(model_path)
        model_file = run_dir + '/model.bin'
        print(model_file)
        Path(run_dir).mkdir(parents=True, exist_ok=True)
        self._model.save_model(model_file)

    def _load_model(self):
        """Load the fastText model from ``config.learn.dir.load_model``."""
        self._model = fasttext.load_model(self._config.learn.dir.load_model)


class FastTextRunner(BaseRunner):
    """fastText text-classification runner whose paths come directly from the config.

    Dataset paths are ``config.data.train_path`` / ``valid_path`` / ``test_path``. Each run's
    outputs go to ``<learn.dir.saved>-<timestamp>/``.
    """

    def __init__(self, config_path: str, model_train=False, model_path=None):
        """Load the config and build paths, model and evaluator.

        :param config_path: path of the config file loaded by ``FastTextConfig``.
        :param model_train: True when driven by a training request. Dataset paths are always
            resolved, ``test`` runs validation and returns a response dict, and no model is
            loaded at construction time.
        :param model_path: optional path of a saved model file. It overrides
            ``config.learn.dir.load_model`` and enables the single-document branch of ``test``.
        """
        super(FastTextRunner, self).__init__()
        self._config_path = config_path
        self._config = None

        # Timestamp naming this run's output directory (<saved>-<time>/).
        self._time = time.strftime('%Y_%m_%d-%H_%M_%S')
        self._model_train = model_train
        self._model_path = model_path

        # Filled by _build_data depending on config.status. Declared here so they always
        # exist (pred loads stop words lazily when they were not preloaded).
        self._train_path = None
        self._valid_path = None
        self._test_path = None
        self._stop_words = None

        self._train_dataloader = None
        self._valid_dataloader = None
        self._test_dataloader = None

        self._model = None
        self._loss = None
        self._optimizer = None

        self._evaluator = None
        self._build()

    @timeit
    def _build(self):
        """Run every build step in dependency order (config first)."""
        self._build_config()
        self._build_data()
        self._build_model()
        self._build_loss()
        self._build_optimizer()
        self._build_evaluator()

    @timeit
    def _build_config(self):
        """Load the config object from ``self._config_path``."""
        self._config = FastTextConfig(config_path=self._config_path).load_config()

    @timeit
    def _build_data(self):
        """Resolve dataset paths for train/test, otherwise preload stop words for prediction."""
        if self._config.status in ['train', 'test'] or self._model_train:
            self._train_path = self._config.data.train_path
            self._valid_path = self._config.data.valid_path
            self._test_path = self._config.data.test_path
        else:
            self._stop_words = self._load_stop_words()

    def _load_stop_words(self):
        """Load the stop-word list that ``pred`` passes to ``seg``, located from ``config.data.dir``."""
        return stop_words(
            path=os.path.join(self._config.data.dir, '../word2vec/f_zp_gp/stop_words.txt')
        )

    @timeit
    def _build_model(self):
        """Apply the ``model_path`` override and load the model for test/pred (non-training) runs."""
        if self._model_path:
            self._config.learn.dir.load_model = self._model_path
        if self._config.status in ['test', 'pred'] and not self._model_train:
            self._load_model()

    @timeit
    def _build_loss(self):
        """Not used: training is delegated to ``fasttext.train_supervised``."""
        pass

    @timeit
    def _build_optimizer(self):
        """Not used: training is delegated to ``fasttext.train_supervised``."""
        pass

    @timeit
    def _build_evaluator(self):
        """Create the evaluator used by ``_valid``."""
        self._evaluator = ClassifyEvaluator()

    @timeit
    def train(self, auto_tune_duration=5000, auto_tune_model_size='200M'):
        """Train a supervised fastText model with autotuning, then save it.

        :param auto_tune_duration: passed to fastText as ``autotuneDuration``.
        :param auto_tune_model_size: passed to fastText as ``autotuneModelSize`` (e.g. '200M').
        """
        self._model = fasttext.train_supervised(
            input=self._train_path,
            # The test split drives hyper-parameter search; _valid reports on the valid split.
            autotuneValidationFile=self._test_path,
            autotuneDuration=auto_tune_duration,
            autotuneModelSize=auto_tune_model_size
        )
        self._save_model()

    def _train_epoch(self, epoch: int):
        """Not used: fastText trains all epochs inside ``train_supervised``."""
        pass

    @staticmethod
    def _parse_line(line: str):
        """Split a fastText-format line ``__label__<label> <text>`` into ``(label, text)``.

        The text excludes the label token and the trailing line break. This works for labels of
        any length and for a last line without a newline.
        """
        stripped = line.replace('__label__', '').rstrip('\r\n')
        label, _, text = stripped.partition(' ')
        return label, text

    def _predict_label(self, text: str) -> str:
        """Return the model's top label for ``text`` without the '__label__' prefix."""
        return self._model.predict(text)[0][0].replace('__label__', '')

    def _run_dir(self) -> str:
        """Return this run's output directory: ``learn.dir.saved`` + '-<time>'."""
        return self._config.learn.dir.saved + '-{}'.format(self._time)

    def _valid(self, epoch: int) -> Optional[dict]:
        """Evaluate the model on the validation split.

        When ``config.status == 'train'`` or ``model_train`` is set, the evaluator's result dict
        is written as JSON to ``<run dir>/evaluation_metrics.json``.

        :param epoch: unused; kept for interface compatibility.
        :return: a training response dict when ``model_train`` is set, otherwise None.
        """
        with open(self._valid_path, encoding='utf-8') as file:
            self._valid_dataloader = file.readlines()
        labels = []
        pre_labels = []
        for line in self._valid_dataloader:
            if not line.strip():
                # Blank lines (e.g. a trailing empty line) are not samples.
                continue
            label, text = self._parse_line(line)
            labels.append(label)
            pre_labels.append(self._predict_label(text))

        _, _, _, dict_result = self._evaluator.evaluate(true_list=labels, pred_list=pre_labels)

        if self._config.status == 'train' or self._model_train:
            run_dir = self._run_dir()
            # The directory may not exist yet if no model was saved in this run.
            Path(run_dir).mkdir(parents=True, exist_ok=True)
            with open(run_dir + '/evaluation_metrics.json', 'w', encoding='utf-8') as f:
                f.write(json.dumps(dict_result))
        if self._model_train:
            average = dict_result['average']
            return {
                'code': 200,
                'result': '模型训练成功！模型评测指标为: precision: {:.0f}%  recall: {:.0f}%  f1-score: {:.0f}%'.format(
                    average['precision'] * 100,
                    average['recall'] * 100,
                    average['f1-score'] * 100
                ),
                'model_path': self._run_dir() + '/model.bin'
            }
        return None

    def test(self, title=None, content=None) -> Optional[dict]:
        """Validate, or predict one document, depending on how the runner was created.

        - ``model_train``: validate and return the training response dict.
        - ``model_path`` set: read ``evaluation_metrics.json`` from the model's directory,
          predict ``title``/``content`` and return a success response containing the label and
          formatted metrics, or ``pred``'s failure dict.
        - otherwise: validate and return None.
        """
        if self._model_train:
            return self._valid(epoch=100)
        if not self._model_path:
            self._valid(epoch=100)
            return None

        metrics_path = os.path.join(os.path.dirname(self._model_path), 'evaluation_metrics.json')
        with open(metrics_path, 'r', encoding='utf-8') as f:
            average = json.load(f)['average']
        evaluation_metrics = {
            '精确率(P)': '{:.0f}%'.format(average['precision'] * 100),
            '召回率(R)': '{:.0f}%'.format(average['recall'] * 100),
            'F1值(F1)': '{:.0f}%'.format(average['f1-score'] * 100)
        }
        result = self.pred(title=title, content=content)
        if not isinstance(result, str):
            # pred returned its failure response; pass it through unchanged.
            return result
        return {
            'handleMsg': 'success',
            'code': 200,
            'logs': '模型测试成功！',
            'result': {
                'label': result,
                'evaluation_metrics': evaluation_metrics
            }
        }

    @staticmethod
    def _failure(value) -> dict:
        """Build the failure response ``pred`` returns for non-string input."""
        return {
            'handleMsg': 'failure',
            'code': 500,
            'logs': '{} is not str!'.format(value),
            'result': {
                'label': None
            }
        }

    def pred(self, title: str, content: str) -> Union[str, dict]:
        """Predict the label of one document.

        The title is repeated twice, each copy followed by '。', before the content. Presumably
        this weights the title more heavily. The text is cleaned, stripped of line breaks and tabs,
        and segmented with the stop-word list.

        :return: the predicted label string, or a failure dict when the title, content or cleaned
            text is not a string (e.g. NaN cells read from Excel).
        """
        for value in (title, content):
            if not isinstance(value, str):
                return self._failure(value)
        text = clean_txt(raw=clean_tag(text=(title + '。') * 2 + content))
        if not isinstance(text, str):
            return self._failure(text)
        text = text.replace('\n', '').replace('\r', '').replace('\t', '')
        if self._stop_words is None:
            # Stop words are only preloaded for status 'pred'; load them on demand otherwise.
            self._stop_words = self._load_stop_words()
        text = seg(text=text, sw=self._stop_words)
        return self._predict_label(text)

    def pred_file(self, file_path: str, result_path: str) -> Optional[dict]:
        """Label every row of an Excel file and write the result to ``result_path``.

        Reads the ``title`` and ``content`` columns and adds a ``label`` column: '是' when the
        predicted label is '1', '否' otherwise.

        :return: None on success. For the first row that cannot be predicted, returns ``pred``'s
            failure dict; nothing is written in that case.
        """
        data_loader = pd.read_excel(file_path)
        labels = []
        for title, content in zip(data_loader['title'], data_loader['content']):
            pred_result = self.pred(title, content)
            if not isinstance(pred_result, str):
                return pred_result
            labels.append('是' if pred_result == '1' else '否')

        data_loader['label'] = labels
        data_loader.to_excel(result_path)
        return None

    def _display_result(self, dict_result: dict):
        """Not implemented for this runner."""
        pass

    def _save_model(self):
        """Save the trained model to ``<run dir>/model.bin``, creating the directory first."""
        run_dir = self._run_dir()
        Path(run_dir).mkdir(parents=True, exist_ok=True)
        self._model.save_model(run_dir + '/model.bin')

    def _load_model(self):
        """Load the fastText model from ``config.learn.dir.load_model``."""
        self._model = fasttext.load_model(self._config.learn.dir.load_model)


if __name__ == '__main__':
    # Choose which "Belt and Road" model to run by leaving exactly one config line active.
    # Project news identification / screening model (active):
    selected_config = '../config/config_br_pro_info_filter.yml'
    # Project information knowledge classification model:
    # selected_config = '../config/config_br_pro_info_type.yml'
    # Project business-opportunity recognition and analysis model:
    # selected_config = '../config/config_br_buss_op_recognition.yml'
    # Project risk-information recognition and analysis model:
    # selected_config = '../config/config_br_pro_risk_recognition.yml'
    # Project news positive/negative sentiment analysis model:
    # selected_config = '../config/config_br_pro_sentiment_analysis.yml'

    fast_text_runner = FastTextRunner(config_path=selected_config)
    # Uncomment to retrain with fastText autotuning before testing:
    # fast_text_runner.train(auto_tune_duration=15000)
    fast_text_runner.test()
