#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time    : 2021/5/11 20:33
# @Author  : 程婷婷
# @FileName: XgboostClassifyRunner.py
# @Software: PyCharm
from model.base import BaseRunner
from model.classify import FastTextProcess
from model.classify import FastTextModel
from model.classify import FastTextEvaluator


class FastTextRunner(BaseRunner.BaseRunner):
    """Train a fastText classifier and evaluate it on the held-out test split.

    Wires together the project's preprocessing (``FastTextProcess``),
    model-building (``FastTextModel``) and evaluation (``FastTextEvaluator``)
    components, each constructed from the same config file.
    """

    # Prefix marking a label token in fastText's supervised text format.
    LABEL_PREFIX = '__label__'

    def __init__(self, config_path):
        """Build the process/model/evaluator helpers from *config_path*.

        :param config_path: path to the configuration file; forwarded to the
            base runner and to every helper component.
        """
        super().__init__(config_path)
        self.ftp = FastTextProcess.FastTextProcess(config_path)
        self.ftm = FastTextModel.FastTextModel(config_path)
        self.fte = FastTextEvaluator.FastTextEvaluator(config_path)

    @classmethod
    def _strip_label_prefix(cls, token):
        """Return *token* without a leading ``__label__`` prefix, if present."""
        if token.startswith(cls.LABEL_PREFIX):
            return token[len(cls.LABEL_PREFIX):]
        return token

    @classmethod
    def _parse_labeled_line(cls, line):
        """Split one ``__label__<label> <text>`` line into ``(label, text)``.

        The whole first whitespace-delimited token is taken as the label, so
        multi-character labels (e.g. ``10``) stay intact. Returns ``None`` for
        blank lines so callers can skip them.
        NOTE(review): only the first label of a multi-label line is used;
        any further ``__label__`` tokens remain part of the text.

        :param line: one raw line from the test file, with or without a
            trailing newline.
        :return: ``(label, text)`` as strings, or ``None`` if the line is blank.
        """
        # Strip only the line terminator; a missing trailing newline on the
        # last line must not cost us a character of the sample text.
        parts = line.rstrip('\r\n').split(maxsplit=1)
        if not parts:
            return None
        label = cls._strip_label_prefix(parts[0])
        text = parts[1] if len(parts) > 1 else ''
        return label, text

    def train(self, logger):
        """Preprocess data, train the model, then evaluate it on the test split.

        :param logger: logger forwarded to the preprocessing and evaluation
            steps.
        :return: the string ``'success'`` once evaluation has run.
        :raises ValueError: if a label (true or predicted) is not an integer.
        """
        train_path, test_path = self.ftp.runner_process(logger)
        # The test file is also handed to the model builder as
        # ``autotuneValidationFile``.
        model = self.ftm.building_model(input=train_path, autotuneValidationFile=test_path)
        true_labels, predict_labels = [], []
        with open(test_path, encoding='utf8') as file:
            for line in file:
                parsed = self._parse_labeled_line(line)
                if parsed is None:
                    continue  # blank lines carry no sample
                label, text = parsed
                true_labels.append(int(label))
                # predict() returns (labels, probabilities); take the top label.
                predicted = model.predict(text)[0][0]
                predict_labels.append(int(self._strip_label_prefix(predicted)))
        evaluate_result = self.fte.evaluate(true_labels, predict_labels, label_mapping=None, logger=logger)
        print(evaluate_result)
        return 'success'
# if __name__ == '__main__':
#     state = FastTextRunner().train()
#     print(state)