#!/usr/bin/python3
import sys
from testhelpers import *
from testutil import *
from testexec import *
import testassertions
import os, glob
import datetime
from xml.dom.minidom import Document, parse
from threading import Thread
from queue import Queue
import random
import uuid
import shutil
from datetime import datetime
from commonutils import *
import traceback
# time is used directly below (sleep/strftime/gmtime); import it explicitly
# instead of relying on it leaking in through the star imports above.
import time

def help():
    """Print the command-line usage text and terminate the script via sys.exit().

    The name shadows the builtin ``help`` but is kept because the rest of the
    script calls it. Every option is printed on its own line.
    """
    print("Usage: testcontroller hvs= <comma separated list of HV to use for the test> \n"
          "\tseqdir=<directory containing the sequence files to run> \n"
          "\toutdir=<directory to place the test reports in> \n"
          "\tseq=<comma separated sequence names, without the .seq extension> members=<no of members for datastore> \n"
          "\tstripes=<no of stripes> replicas=<no of replicas> \n"
          "\tperf=<performance level for datastore> \n"
          "\tid=<unique id to be used for this test>\n"
          "\t[datastore=<uuid of an existing datastore to run against>]\n"
          "\t[clean=true] [assert=true] [deamonize=true]\n"
          "\t[hvconfig=<hvconfig path>] [prepare=true] \n"
          "\t[restartgroupmon=true]\n"
          "\t[refreshbuild=true]\n"
          "\t[longhaul=true]\n"
          "\t[times=<number of times>]\n"
          "\t[randomds=true]\n"
          "\t[randomtest=true]\n"
          "\t[disablespacetest=true]\n"
          "\t[bs=<comma separated list of BS to use for the test>]\n"
          "\t[epoch=<true/false>]\n")

    sys.exit()

# With no key=value arguments after the program name there is nothing to run:
# print the usage text (help() exits the script).
if len(sys.argv) <= 1:
    help()

def ensureRequiredParamsPresent(dict, required_keys):
    """Show the usage text and exit unless every key in ``required_keys`` is in ``dict``.

    ``dict`` is the parsed argument mapping; the parameter name shadows the
    builtin but is kept so existing keyword callers keep working. help()
    already exits, so the trailing sys.exit() is only a safety net.
    Returns None when all keys are present.
    """
    all_present = all(key in dict for key in required_keys)
    if not all_present:
        help()
        sys.exit()

class Test():
    """One test run driven by the parsed ``key=value`` command-line arguments.

    Constructing a Test creates ``<outdir>/<id>/`` and fills it with marker
    files whose names carry metadata (status, CP/onappstore versions, epoch,
    overcommit, PID, ...). run() prepares the HVs, executes the requested
    sequences and finally replaces ``status-RUNNING`` with ``status-PASS`` or
    ``status-FAIL`` (automatic tests only).
    """

    def __init__(self, args):
        """Parse ``args``, create the output folder and write the initial marker files.

        ``args`` is the dictionary built by ``parseargs``; ``hvs`` (comma
        separated), ``seq``, ``seqdir`` and ``outdir`` are required.
        Raises whatever os.mkdir raises if the output folder already exists.
        """
        # Keep the arguments so run() does not depend on the module-level
        # ``args`` global.
        self.args = args
        self.hvs = args['hvs'].split(',')
        self.dsparams = {}
        # Attributes read by run()'s ``finally`` clause (and the log handle used
        # by its ``except``). Initialising them here keeps an early failure
        # from being masked by an AttributeError during teardown.
        self.ds = None
        self.channelflag = None
        self.results = {}
        # NOTE(review): manual tests keep rdlogfile=None and pass it to the
        # logging/cleanup helpers; confirm those helpers accept None.
        self.rdlogfile = None
        self.extractCmdLineArgs(args)
        self.outfolder = os.path.join(self.outdir, self.uuid)
        print(self.outfolder)
        os.mkdir(self.outfolder)
        self.statusfile = os.path.join(self.outfolder, 'status-RUNNING')
        touchFile(self.statusfile)
        # 'assert-True' / 'assert-False' depending on the assert= option.
        self.assertfile = os.path.join(self.outfolder, 'assert-%s' % self.runAssertions)
        if self.type != MANUAL_TEST:
            # Readable summary log, opened as a raw O_SYNC file descriptor.
            self.rdlogfile = os.open(os.path.join(self.outfolder, 'summary'), os.O_CREAT | os.O_RDWR | os.O_SYNC)
            printToBoth("********  Welcome to OnAppTest  ********", self.rdlogfile)
            printToBoth("****************************************", self.rdlogfile)
            printToBoth("Test ID: %s" % self.uuid, self.rdlogfile)
        # Both test types need the HV -> IP map (the rpm/build-time lookups
        # below and run() use it); setting it for automatic tests only made
        # manual tests crash with AttributeError.
        self.hvToIp = getHVToIP(self.hvs)
        self.hvToMac = getHVToMac(self.hvToIp) if self.type != MANUAL_TEST else {}
        self.hvToUuid = getHVToUuid(self.hvs)
        self.hvToRpm = getHVToRpmVersion(self.hvs, self.hvToIp)
        self.hvToBuildtime = getHVToRpmBuildTime(self.hvs, self.hvToIp)
        # No XML report handle is kept open; run() hands the report path to
        # saveTestReport instead.
        self.parselogfile = None
        self.path = self.outfolder
        self.pidfile = os.path.join(self.outfolder, 'PID')
        # Marker files whose names record the environment of this run.
        touchFile(os.path.join(self.outfolder, 'SEQS--pending'))
        touchFile(os.path.join(self.outfolder, 'CP--%s' % getCPVersion()))
        touchFile(os.path.join(self.outfolder, 'onappstore--%s' % getOnappstoreVersion()))
        touchFile(os.path.join(self.outfolder, 'epoch--%s' % self.epoch))
        touchFile(os.path.join(self.outfolder, 'overcommit--%s' % self.overcommit))

        # generate pid file
        setPID(self.pidfile)

    @staticmethod
    def _dsFlagName(dsparams):
        """Return the ``DS-members=..._perf=...`` marker name for ``dsparams``.

        An empty mapping yields all zeros; a missing 'P' (perf) is shown as 'X'.
        """
        if not dsparams:
            return 'DS-members=0_replicas=0_stripes=0_perf=0'
        perf = dsparams['P'] if 'P' in dsparams else 'X'
        return 'DS-members=%s_replicas=%s_stripes=%s_perf=%s' % (
            dsparams['M'], dsparams['R'], dsparams['S'], perf)

    def run(self):
        """Prepare the HVs, run every requested sequence and record the outcome.

        Any exception is logged to the summary and marks the test FAIL. The
        ``finally`` clause always:
        - deletes a datastore this test created,
        - writes the report (interactive automatic runs),
        - removes /tmp/hvstate and the PID file,
        - writes the final ``status-PASS``/``status-FAIL`` marker (automatic
          tests).
        The test is also FAIL if any output file name contains '__FAIL'.
        """
        status = 'PASS'
        try:
            # prepare HVs
            if 'hvconfig' in self.args:
                self.hvinfo = str(readHvInfo(os.path.abspath(self.args['hvconfig']))).replace(' ', '')
                if self.prepare:
                    prepareHVs(self.hvinfo, self.rdlogfile, self.refreshbuild)
            else:
                self.hvinfo = ''

            self.hvToUse = getLiveHV(self.hvToIp)
            # read hv config
            self.hvconfig = getHVConfig(self.hvs, self.hvToIp, self.hvToUuid, bs=self.bs)

            # cleanup the SAN before the first sequence
            if self.clean:
                performCleanup(list(self.hvconfig.keys()), self.rdlogfile, self.restartgroupmon, self.hvToIp, self.runAssertions, cleanvms=True)

            if self.randomds:
                self.dsparams = getRandomDsParams(list(self.hvconfig.keys())[0])

            if self.datastore:
                # Run against an existing datastore: describe its layout so the
                # flags and sequences see the real members/replicas/stripes.
                datastoreinfo = getDsInfo(self.datastore, getLiveHV(self.hvToIp))
                self.dsparams = {
                    'uuid': self.datastore,
                    'M': len(datastoreinfo['members'].split(',')),
                    'R': datastoreinfo['replicas'],
                    'S': datastoreinfo['stripes'],
                }
                print(self.dsparams)
            if 'M' not in self.dsparams:
                self.dsparams['M'] = getNumberOfAvailableMembers(list(self.hvconfig.keys())[0], self.type != MANUAL_TEST)

            touchFile(os.path.join(self.outfolder, self._dsFlagName(self.dsparams)))

            for hv in self.hvs:
                hvflag = os.path.join(self.outfolder, 'HV-%s-%s-%s-%s-%s-%s-%d-%s' %
                                      (hv, findHVType(hv), self.hvToRpm[hv], getMacForHV(hv), 'PV',
                                       getUnicastmodeForHV(hv), getCentosVersion(hv), self.hvToBuildtime[hv]))
                # Only the last HV's channel flag is remembered (and removed at the end).
                self.channelflag = os.path.join(self.outfolder, 'CHANNEL-%s' % getChannelForHV(hv))
                touchFile(hvflag)
                touchFile(self.channelflag)

            touchFile(self.statusfile)
            touchFile(os.path.join(self.outfolder, 'type-%s' % self.type))

            if self.type != MANUAL_TEST:
                printHVInfo(self.hvToIp, self.rdlogfile, hvconfig=self.hvconfig, hvToRpm=self.hvToRpm)

            # Replace the SEQS--pending placeholder with the actual sequence list.
            os.unlink(os.path.join(self.outfolder, 'SEQS--pending'))
            touchFile(os.path.join(self.outfolder, 'SEQS--%s' % ','.join(self.seqtorun)))

            # persist HV config to disk
            persistHVConfig(self.hvconfig)

            if self.type == MANUAL_TEST:
                # Manual tests stop after setting up the environment.
                self.ds = None
                return

            # read sequence config
            self.seqconfig = readSeqConfig(self.seqdir)

            if not self.deamonize:
                printToBoth("Test configuration", self.rdlogfile)
                printToBoth("----------------", self.rdlogfile)
                printToBoth("----------------", self.rdlogfile)
                for seq in self.seqtorun:
                    printToBoth("Sequence:%s.seq" % seq, self.rdlogfile)
                    printToBoth("----------------", self.rdlogfile)
                printToBoth("----------------", self.rdlogfile)

            self.results = {}
            if self.longhaul:
                # Pick sequences at random; times=0 (the default) means no limit.
                seqindex = 1
                while not self.times or seqindex <= self.times:
                    self.runSequence(random.choice(self.seqtorun), seqindex)
                    seqindex += 1
            else:
                for seqindex, seq in enumerate(self.seqtorun, start=1):
                    if seqindex != 1:
                        # best-effort cleanup of the SAN before the next sequence
                        performCleanup(list(self.hvconfig.keys()), self.rdlogfile, self.restartgroupmon, self.hvToIp, self.runAssertions, cleanvms=True)
                    self.runSequence(seq, seqindex)
        except Exception as e:
            printToBoth("Test failed with exception: %s, traceback: %s" % (str(e), traceback.format_exc()), self.rdlogfile)
            status = 'FAIL'

        finally:
            if self.ds is not None:
                time.sleep(5)
                # deleteds is True only when runSequence() created the datastore.
                if self.deleteds:
                    # A teardown failure must not skip the bookkeeping below,
                    # otherwise the folder is left in status-RUNNING forever.
                    try:
                        if 'P' in self.dsparams:
                            putDS(self.hvToUse, self.path, self.ds, self.rdlogfile)
                        else:
                            justDeleteDS(self.hvToUse, self.ds, self.rdlogfile)
                    except Exception as e:
                        printToBoth("Datastore cleanup failed: %s, traceback: %s" % (str(e), traceback.format_exc()), self.rdlogfile)
                        status = 'FAIL'

            if not self.deamonize and self.type != MANUAL_TEST:
                saveTestReport(self.results, self.uuid, self.rdlogfile, self.parselogfile, os.path.join(self.outfolder, 'summary'), os.path.join(self.outfolder, 'parsable-test-report.xml'))
            hvstate = os.path.join('/tmp', 'hvstate')
            if os.path.exists(hvstate):
                shutil.rmtree(hvstate)
            resetPID(self.pidfile)
            if self.channelflag is not None and os.path.exists(self.channelflag):
                os.unlink(self.channelflag)

            # A run that did not raise still fails if any output file name
            # carries a '__FAIL' marker.
            if status != 'FAIL' and any('__FAIL' in os.path.basename(entry)
                                        for entry in glob.glob('%s/*' % self.outfolder)):
                status = 'FAIL'
            if self.type != MANUAL_TEST:
                if os.path.exists(self.statusfile):
                    os.unlink(self.statusfile)
                touchFile(os.path.join(self.outfolder, 'status-%s' % status))

    def runSequence(self, seq, seqindex):
        """Run sequence ``<seq>.seq`` once and append its results to ``self.results[seqindex]``.

        Creates a datastore from ``self.dsparams`` on first use, or reuses the
        uuid already stored there. Each iteration runs in a thread from
        runSeqInThread, which reports ``(seq, testfileToThreadNoToPair)`` on a
        Queue.
        """
        seq = seq + ".seq"
        # Fixed run policy: mode 'S' (no randomisation; 'M' would randomise)
        # and times 'R1' = one repetition ('T<minutes>' would be a time budget).
        mode = 'S'
        times = 'R1'

        self.ds = None
        # get datastore, create if required
        if self.dsparams != {}:
            if 'uuid' not in self.dsparams:
                dsinfo = getDS(self.dsparams, self.hvconfig, self.path, self.rdlogfile)
                self.ds = dsinfo['uuid']

                # Set state 1 and the configured epoch flag ('0'/'1' from the
                # epoch= option) on every owner node of the new datastore.
                hvToUse = self.hvs[0]
                for node in dsinfo['owners'].split(','):
                    cmd = 'curl -SsX PUT localhost:8080/is/Node/%s -d \'{\"state\":\"1\",\"epoch\":\"%s\"}\'' % (node, self.epoch)
                    runCommandOnNodeAndGetOutput(hvToUse, cmd)

                # This test created the datastore, so run() must delete it.
                self.deleteds = True
            else:
                self.ds = self.dsparams['uuid']
            self.dsparams['uuid'] = self.ds
        else:
            self.dsparams['uuid'] = ''
            self.dsparams['R'] = ''
            self.dsparams['S'] = ''

        threads = []
        started = datetime.now()
        iteration = 0
        while True:
            if times.startswith('R') and int(times[1:]) <= iteration:
                break
            if times.startswith('T') and (datetime.now() - started).seconds > int(times[1:]) * 60:
                break

            iteration += 1
            randomize = (mode == 'M')

            # run this sequence in a separate thread
            printToBoth("Running sequence: %s, Iteration: %d" % (seq, iteration), self.rdlogfile)
            q = Queue()
            t = runSeqInThread(self.hvconfig, self.seqconfig[seq], seq, self.path, q, self.hvToIp, self.hvToMac, self.hvToUuid, self.dsparams, randomize, seqindex, iteration, self.runAssertions, self.hvinfo, self.rdlogfile, self.longhaul, self.randomtest, self.disablespacetest, self.bs, self.ds)
            t.start()
            t.join()
            threads.append((t, q))

        # Every thread is already joined above; joining again is a cheap safety
        # net that replaces polling Thread.isAlive(), removed in Python 3.9.
        for t, _ in threads:
            t.join()
        for _, q in threads:
            seqname, testfileToThreadNoToPair = q.get()
            self.results.setdefault(seqindex, []).append((seqname, testfileToThreadNoToPair))

    @staticmethod
    def _flag(args, name):
        """Return True only when option ``name`` was given with the exact value 'true'."""
        return name in args and args[name] == 'true'

    def extractCmdLineArgs(self, args):
        """Populate the test settings from the parsed ``key=value`` arguments.

        ``seq``, ``seqdir`` and ``outdir`` are required. ``replicas`` and
        ``stripes`` are also required unless ``randomds=true`` or
        ``datastore=<uuid>`` is given. Boolean options are enabled only by the
        exact string 'true'. Raises KeyError for a missing required key.
        """
        self.seqtorun = args['seq'].split(',')

        # Only a datastore created by runSequence() is ever deleted.
        self.deleteds = False
        self.datastore = args['datastore'] if 'datastore' in args else ''

        self.overcommit = '0'
        self.randomds = self._flag(args, 'randomds')
        if not self.randomds and not self.datastore:
            if 'members' in args:
                self.dsparams['M'] = args['members']
            self.dsparams['R'] = args['replicas']
            self.dsparams['S'] = args['stripes']
            if 'overcommit' in args:
                self.overcommit = args['overcommit']
                self.dsparams['O'] = args['overcommit']
        if 'perf' in args:
            self.dsparams['P'] = args['perf']
        self.seqdir = os.path.abspath(args['seqdir'])
        self.outdir = os.path.abspath(args['outdir'])
        self.runAssertions = self._flag(args, 'assert')

        # The usage text has advertised 'cleanup=true' while this read 'clean',
        # so accept both. An existing datastore (datastore=<uuid>) is never
        # cleaned; that intent used to be overwritten by the clean= flag.
        self.clean = (self._flag(args, 'clean') or self._flag(args, 'cleanup')) and not self.datastore

        self.deamonize = self._flag(args, 'deamonize')
        self.prepare = self._flag(args, 'prepare')
        self.restartgroupmon = self._flag(args, 'restartgroupmon')
        self.refreshbuild = self._flag(args, 'refreshbuild')

        if 'id' in args:
            self.uuid = args['id']
        else:
            self.uuid = time.strftime('%d-%m-%Y_%H-%M-%S', time.gmtime())

        self.longhaul = self._flag(args, 'longhaul')
        # 0 means "no limit" for longhaul runs.
        self.times = int(args['times']) if 'times' in args else 0
        self.randomtest = self._flag(args, 'randomtest')
        self.disablespacetest = self._flag(args, 'disablespacetest')
        self.bs = args['bs'] if 'bs' in args else '0.0.0.0'

        if 'type' in args and args['type'] == MANUAL_TEST:
            self.type = MANUAL_TEST
        else:
            self.type = AUTO_TEST

        self.epoch = '1' if self._flag(args, 'epoch') else '0'

# Parse the key=value command-line arguments into a dictionary; the usage text
# is printed and the script exits if a mandatory key is missing.
args = parseargs(sys.argv)
ensureRequiredParamsPresent(args, ['hvs', 'seqdir', 'outdir', 'seq'])

# Refuse to start unless every requested HV answers. Exit non-zero so callers
# (shell scripts, schedulers) can tell that the run never started.
for hv in args['hvs'].split(','):
    if not checkLiveness(hv, noretry=True):
        print("HV %s not reachable, please specify HVs which are active and running." % hv)
        sys.exit(1)

# Resolve the test ID (defaulting to a UTC timestamp) and store it back in
# args so the printed ID is the one Test() uses for its output folder. The
# variable is named test_id so it does not shadow the imported uuid module.
if 'id' not in args:
    test_id = time.strftime('%d-%m-%Y_%H-%M-%S', time.gmtime())
else:
    test_id = args['id']
print("Test ID: %s" % test_id)
args['id'] = test_id

if 'deamonize' in args and args['deamonize'] == 'true':
    deamonize()
t = Test(args)
t.run()
