#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time    : 2021/5/11 20:33
# @Author  : 程婷婷
# @FileName: LogisticClassifyRunner.py
# @Software: PyCharm
from model.base.views.runner.BaseRunner import BaseRunner
from model.classify.views.logistic_classify.data.LogisticClassifyProcess import LogisticClassifyProcess
from model.classify.views.logistic_classify.LogisticClassifyModel import LogisticClassifyModel
from model.classify.views.logistic_classify.LogisticClassifyEvaluator import LogisticClassifyEvaluator


class LogisticClassifyRunner(BaseRunner):
    """Runner tying together the LogisticClassify process, model and evaluator.

    All three collaborators are built from the same ``config_path``. ``train``
    drives a two-stage flow: a title stage first, then a content stage that is
    fed the index retained by the title stage.
    """

    def __init__(self, config_path):
        """Initialise the base runner and create the collaborators.

        Args:
            config_path: Forwarded unchanged to ``BaseRunner`` and to the
                process, model and evaluator constructors.
        """
        super().__init__(config_path)
        # Data preparation helper (title / content processing).
        self.lcp = LogisticClassifyProcess(config_path)
        # Model helper; exposes ``building_model``.
        self.lcm = LogisticClassifyModel(config_path)
        # Evaluator; constructed here but not used by ``train``.
        self.lce = LogisticClassifyEvaluator(config_path)

    def train(self, logger):
        """Run the title stage followed by the content stage.

        Args:
            logger: Passed through to ``title_process`` and to both
                ``building_model`` calls.

        Returns:
            The literal string ``'success'``. The values produced by the
            second ``building_model`` call are unpacked but not returned.
        """
        # --- Stage 1: titles -------------------------------------------------
        # Names suggest TF-IDF features, IDF weights and labels; the exact
        # types are defined by LogisticClassifyProcess -- TODO confirm.
        title_outputs = self.lcp.title_process(logger)
        title_tfidf, _title_idf, sample_labels = title_outputs

        title_stage = self.lcm.building_model(
            tfidf_title=title_tfidf,
            labels=sample_labels,
            logger=logger,
        )
        _title_threshold, kept_title_index, _dropped_title_index = title_stage

        # --- Stage 2: contents -----------------------------------------------
        # Only the index retained by the title stage is handed to content
        # processing.
        content_tfidf, _content_idf = self.lcp.content_process(kept_title_index)

        # r is tunable; per the original author's note, training ultimately
        # stops once recall drops below r.
        content_stage = self.lcm.building_model(
            labels=sample_labels,
            tfidf_content=content_tfidf,
            r=0.8,
            logger=logger,
        )
        _content_threshold, _kept_content_index, _dropped_content_index = content_stage

        return 'success'


# Example usage (kept disabled; supply a real config path and logger):
# if __name__ == '__main__':
#     state = LogisticClassifyRunner(config_path).train(logger)
#     print(state)
