import sys
from os import listdir, scandir, walk
from os.path import isfile, join
import traceback

from Scanning import scan
from Parsing import parse
from AstBuilding import astBuild, buildEnvAndLink, disamiguateAndTypeChecking
import Weeding


# run test with python Test.py Tests/A2/J1_3_ImportOnDemand_ProgramDefinedPackage/


def main():
    """Script entry point: hand control to the multi-test-case driver."""
    a2Multiple()

def allFiles(testDir):
    """Return the paths of the 'J1' test files directly inside testDir.

    Only regular files whose name starts with 'J1' are returned; the listing
    is not recursive and keeps the order produced by listdir().

    The returned paths are built with os.path.join, so testDir works with or
    without a trailing path separator (plain string concatenation produced
    broken paths when the separator was missing).
    """
    candidates = (join(testDir, name) for name in listdir(testDir) if name.startswith('J1'))
    return [path for path in candidates if isfile(path)]

def a1():
    """Run the A1 tests.

    Command-line arguments, when present, are taken as file names relative to
    ./Tests/; with no arguments every file selected by allFiles("./Tests/")
    is used.  The selected files are passed to run() in a single call.
    """
    requested = sys.argv[1:]
    if requested:
        testFiles = ["./Tests/" + name for name in requested]
    else:
        # No explicit selection: fall back to the whole test directory.
        testFiles = allFiles("./Tests/")
    print("**********************************************************")
    run(testFiles)

def a2Single():
    """Placeholder for testing all the single-file cases in A2 (not implemented)."""
    return None

def a2Multiple():
    """Run every multi-file test case and print a score line.

    Test cases are the command-line arguments when any are given; otherwise
    every entry of ./Tests/A3/ is used (directories first, then the remaining
    entries).  A case whose path contains '.java' is treated as a single
    source file, any other case as a directory that is walked recursively.
    Each case is handed to run() together with all files found under
    stdlib/3.0/java/.

    A case whose path contains 'Je_' is expected to be rejected (run()
    returns a non-empty message); every other case is expected to be accepted
    (run() returns "").  Only unexpected outcomes are printed, followed by the
    final "SCORE" line.  With no test cases at all the score is reported as
    0 / 0 -> 0% instead of failing with a division by zero.
    """
    if len(sys.argv) > 1:
        testCases = sys.argv[1:]
    else:
        # All entries in the test directory, listed once: folders first.
        testDirectory = "./Tests/A3/"
        entries = list(scandir(testDirectory))
        testCases = [entry.path for entry in entries if entry.is_dir()]
        testCases += [entry.path for entry in entries if not entry.is_dir()]

    # The stdlib sources are the same for every case, so collect them once
    # instead of re-walking the tree on each iteration.
    stdlibFiles = [join(dp, f) for dp, dn, filenames in walk('stdlib/3.0/java/') for f in filenames]

    total = 0
    correct = 0

    for c in testCases:
        # Fresh copy per case so one case's files never leak into the next.
        testFiles = list(stdlibFiles)

        if '.java' in c:
            # add this one file
            testFiles.append(c)
        else:
            # get all files in the folder recursively
            testFiles += [join(dp, f) for dp, dn, filenames in walk(c) for f in filenames]

        ret = run(testFiles)
        total += 1

        failed = ret != ""
        expectedToFail = 'Je_' in c
        if failed == expectedToFail:
            correct += 1
            continue

        # Unexpected outcome: report what happened for this case.
        print(c)
        print(ret if failed else "JE Passed without error")
        print("**********************************************************")

    # Guard the percentage so an empty test set does not divide by zero.
    percentage = (correct / total) * 100 if total else 0
    print("\nSCORE: {} / {} -> {:.3g}%".format(correct, total, percentage))



def _errorMessage(exc):
    """Return a printable message for exc without assuming the shape of exc.args."""
    args = getattr(exc, "args", None)
    if args:
        return str(args[0])
    return str(exc) or type(exc).__name__


def run(testFiles):
    """Push the given source files through every compiler stage.

    Each file is scanned, checked by Weeding.fileNameCheck and parsed; the
    resulting parse trees are then passed together to astBuild,
    buildEnvAndLink and disamiguateAndTypeChecking.  Entries named .DS_Store
    are skipped.

    Processing stops at the first failure.  Scanning and parsing diagnostics
    are printed only for files whose path does not contain 'Je_'; weeding
    diagnostics are always printed.

    Returns:
        "" when every stage succeeded, otherwise a short message naming the
        stage that failed.
    """
    parseTrees = []

    for f in testFiles:
        if f.split("/")[-1] == ".DS_Store":
            continue
        # Close the file deterministically instead of leaking the handle.
        with open(f, "r") as source:
            content = source.read()

        # Scanning
        (tokens, error) = scan(content)

        # Error in Scanning
        if tokens is None:
            if 'Je_' not in f:
                print(f)
                print("ERROR in Scanning: " + str(error))
                print("**********************************************************")
            return "ERROR in scanning"

        # s = "All Tokens: "
        # for token in tokens:
        #     if (token.name and token.lex):
        #         s += '(' + token.name + ',' + token.lex + '), '
        # print(s)

        # Weeding after scanning
        # No weeds if everything is good (weeds = None)
        weeds = Weeding.fileNameCheck(tokens, f)
        if weeds:
            print(f)
            print("ERROR in Weeding after Scanning:")
            print(weeds)
            print("**********************************************************")
            return "ERROR in weeding"

        # Parsing
        tree = None
        try:
            (tree, error) = parse(tokens)
        except Exception as e:
            print("Exception in Parsing")
            # Remember the failure so the report below describes it rather
            # than reusing the stale error value left over from scanning.
            error = e

        # Error in Parsing
        if tree is None:
            if 'Je_' not in f:
                print(f)
                print("ERROR in Parsing: " + _errorMessage(error))
                # for n in error.args[1]: # the parse tree
                #     print(n)
                print("**********************************************************")
            return "ERROR in parsing"

        parseTrees.append((f, tree))

    # for (n, t) in parseTrees:
    #     if n == "Tests/A3/J1_accessstaticfield/Main.java":
    #         print(n)
    #         print(t)

    ASTs = astBuild(parseTrees)

    # for (n, t) in ASTs:
    #     print(n)
    #     print("--------------------")
    #     t.printTree()
    #     print("\n \n\n \n")

    try:
        buildEnvAndLink(ASTs)
    except Exception as e:
        # _errorMessage copes with exceptions raised without a string argument.
        return "buildEnvAndLink: " + _errorMessage(e)

    try:
        disamiguateAndTypeChecking(ASTs)
    except Exception as e:
        return "disamiguateAndTypeChecking: " + _errorMessage(e)

    return ""

# Start the test run as soon as this module is executed.  There is no
# __main__ guard, so importing the module runs the tests as well.
main()