#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time    : 2021/5/11 20:33
# @Author  : 程婷婷
# @FileName: XgboostClassifyRunner.py
# @Software: PyCharm
import time
import numpy as np
from model.base.views.runner.BaseRunner import BaseRunner
from model.classify.views.xgboost_classify.data.XgboostClassifyProcess import XgboostClassifyProcess
from model.classify.views.xgboost_classify.XgboostClassifyModel import XgboostClassifyModel
from model.classify.views.xgboost_classify.XgboostClassifyEvaluator import XgboostClassifyEvaluator


class XgboostClassifyRunner(BaseRunner):
    """Run the XGBoost classification pipeline: preprocess, train, evaluate.

    The data processor, model builder and evaluator are all built from the
    same config file and share a single run ``signature``.
    """

    def __init__(self, config_path):
        """Create the pipeline components from ``config_path``.

        Args:
            config_path: Path to the configuration file; forwarded to
                ``BaseRunner`` and to every pipeline component.
        """
        super().__init__(config_path)
        # Integer Unix timestamp (seconds) identifying this run; it is passed
        # to both the data processor and the model builder.
        # NOTE(review): presumably used to tag artifacts of this run — confirm.
        self.signature = int(time.time())
        self.xcp = XgboostClassifyProcess(config_path)
        self.xcm = XgboostClassifyModel(config_path)
        self.xce = XgboostClassifyEvaluator(config_path)

    @staticmethod
    def _split_features_labels(dataset):
        """Split a 2-D array into (features, labels) on its last column.

        The last column is returned as ``int64`` labels; all other columns
        are returned as a new feature array (``np.delete`` copies).
        """
        features = np.delete(dataset, -1, axis=1)
        labels = dataset[:, -1].astype(np.int64)
        return features, labels

    def train(self, logger):
        """Preprocess the data, train the model and evaluate it on the test split.

        Args:
            logger: Logger receiving progress messages. It needs ``info`` and
                ``debug`` methods; presumably a ``logging.Logger`` — confirm
                against callers.

        Returns:
            The string ``'success'`` once evaluation has run.

        Notes:
            The prediction threshold is read from
            ``self.runner_config['thres']``. That attribute is not set here,
            so it is presumably provided by ``BaseRunner``.
        """
        train_set, test_set = self.xcp.runner_process(signature=self.signature)
        label_dict = self.xcp.label_mapping
        logger.debug('label mapping: %s', label_dict)

        # In both splits the last column holds the class label.
        X_train, Y_train = self._split_features_labels(train_set)
        logger.debug('X_train shape: %s', X_train.shape)
        logger.debug('train labels present: %s', sorted(set(Y_train.tolist())))

        logger.info('处理后的数据量为 %d 条', len(train_set) + len(test_set))
        logger.info('训练集的数据量为 %d 条', len(train_set))
        logger.info('测试集的数据量为 %d 条', len(test_set))

        model = self.xcm.building_model(
            label_dict,
            self.signature,
            X_train,
            Y_train,
        )

        X_test, true_label = self._split_features_labels(test_set)
        logger.debug('test labels present: %s', sorted(set(true_label.tolist())))

        predict_label = model.predict(X_test, thres=self.runner_config['thres'])
        predict_label = predict_label.tolist()
        logger.debug('predicted labels present: %s', sorted(set(predict_label)))

        self.xce.evaluate(true_label, predict_label, label_dict, logger)
        return 'success'


# if __name__ == '__main__':
#     state = XgboostClassifyRunner().train()
#     print(state)
