import network
import socket
import machine
import ujson
import uhashlib
import ubinascii
import gc
import pycom
import os
import machine
import time

# Try to get version number


class OTA():
    # The following two methods need to be implemented in a subclass for the
    # specific transport mechanism e.g. WiFi

    def connect(self):
        raise NotImplementedError()

    def get_data(self, req, dest_path=None, hash=False):
        raise NotImplementedError()

    # OTA methods

    def get_current_version(self):
        return self.version

    def get_update_manifest(self):
        req = self.endpoint + "/manifest_{}.json".format(self.get_current_version())
        manifest_data = self.get_data(req).decode()
        manifest = ujson.loads(manifest_data)
        gc.collect()
        return manifest

    def update(self, fnwd):
        manifest = self.get_update_manifest()
        if manifest is None:
            print("Already on the latest version.")
            return

        if 'pre' in manifest:
            eval(manifest['pre'])
        # Download new files and verify hashes
        for f in manifest['new'] + manifest['update']:
            # # Upto 5 retries
            # for _ in range(5):
            #     try:
            self.get_file(f)
            #         break
            #     except Exception as e:
            #         print(e)
            #         print("Error downloading `{}` retrying...".format(f['URL']))
            # else:
            #     raise Exception("Failed to download `{}`".format(f['URL']))

        # Backup old files
        # only once all files have been successfully downloaded
        for f in manifest['update']:
            self.backup_file(f)

        # Rename new files to proper name
        for f in manifest['new'] + manifest['update']:
            new_path = "{}.new".format(f['dst_path'])
            dest_path = "{}".format(f['dst_path'])

            os.rename(new_path, dest_path)

        # `Delete` files no longer required
        # This actually makes a backup of the files incase we need to roll back
        for f in manifest['delete']:
            self.delete_file(f)

        # Flash firmware
        if "firmware" in manifest:
            self.write_firmware(manifest['firmware'])


        # Reboot the device to run the new decode
        print("Update complete. Rebooting...")
        time.sleep(2)
        self.imalive(3)
        machine.reset() # just in case imalive custom reset fails

    def get_file(self, f):
        new_path = "{}.new".format(f['dst_path'])

        # If a .new file exists from a previously failed update delete it
        try:
            os.remove(new_path)
        except OSError:
            pass  # The file didnt exist

        # Download new file with a .new extension to not overwrite the existing
        # file until the hash is verified.
        hash = self.get_data(f['URL'].split("/", 3)[-1],
                             dest_path=new_path,
                             hash=True)

        # Hash mismatch
        if hash != f['hash']:
            print(hash, f['hash'])
            msg = "Downloaded file's hash does not match expected hash"
            raise Exception(msg)

    def backup_file(self, f):
        bak_path = "{}.bak".format(f['dst_path'])
        dest_path = "{}".format(f['dst_path'])

        # Delete previous backup if it exists
        try:
            os.remove(bak_path)
        except OSError:
            pass  # There isnt a previous backup

        # Backup current file
        os.rename(dest_path, bak_path)

    def delete_file(self, f):
        bak_path = "/{}.bak_del".format(f)
        dest_path = "/{}".format(f)

        # Delete previous delete backup if it exists
        try:
            os.remove(bak_path)
        except OSError:
            pass  # There isnt a previous delete backup

        # Backup current file if exists
        try:
            os.rename(dest_path, bak_path)
        except OSError:
            pass

    def write_firmware(self, f):
        hash = self.get_data(f['URL'].split("/", 3)[-1],
                             hash=True,
                             firmware=True)
        # TODO: Add verification when released in future firmware


class BaseOTA(OTA):
    def __init__(self, host, port, endpoint, imalive, wd, version="1.0.0"):
        self.ip = socket.getaddrinfo(host, port)[0][4][0]
        self.host = host
        self.port = port
        self.imalive = imalive
        self.wd = wd
        self.version = version
        self.endpoint = endpoint

    def _http_get(self, path, host):
        req_fmt = 'GET /{} HTTP/1.1\r\nHost: {}\r\n\r\n'
        req = bytes(req_fmt.format(path, host), 'utf8')
        return req

    def get_data(self, req, dest_path=None, hash=False, firmware=False):
        # Connect to server
        self.imalive(1)
        self.wd()
        print("Requesting: {}".format(req))
        s = socket.socket(socket.AF_INET,
                          socket.SOCK_STREAM,
                          socket.IPPROTO_TCP)
        s.connect((self.ip, self.port))

        # Request File
        s.sendall(self._http_get(req, self.host))

        try:
            h = uhashlib.sha1()
            content = bytearray()
            fp = None
            if dest_path is not None:
                print(dest_path)
                if firmware:
                    raise Exception("Cannot write firmware to a file")
                try:
                    os.stat(dest_path)
                except:
                    directroiesToMake = dest_path.split("/")[2:-1]
                    for index in range(len(directroiesToMake)):
                        dirms = "/".join(directroiesToMake[0:index+1])
                        try:
                            os.mkdir(dirms)
                            print('Creating directory: ' + dirms)
                        except:
                            pass
                    open(dest_path, 'x')
                fp = open(dest_path, 'wb')

            if firmware:
                pycom.ota_start()



            # Get data from server
            result = s.recv(100)
            start_writing = False
            while (len(result) > 0):
                # Ignore the HTTP headers
                if not start_writing:
                    if "\r\n\r\n" in result:
                        start_writing = True
                        result = result[result.find(b'\r\n\r\n')+4:]

                if start_writing:
                    if firmware:
                        pycom.ota_write(result)
                    elif fp is None:
                        content.extend(result)
                    else:
                        fp.write(result)

                    if hash:
                        h.update(result)

                result = s.recv(100)

            s.close()

            if fp is not None:
                fp.close()
            if firmware:
                pycom.ota_finish()

        except Exception as e:
            # Since only one hash operation is allowed at Once
            # ensure we close it if there is an error
            if h is not None:
                h.digest()
            raise e

        hash_val = ubinascii.hexlify(h.digest()).decode()

        if dest_path is None:
            if hash:
                return (bytes(content), hash_val)
            else:
                return bytes(content)
        elif hash:
            return hash_val
