#!/usr/bin/python3
# -*- coding: utf-8 -*-
#
# «recovery_common» - Misc Functions and variables that are useful in many areas
#
# Copyright (C) 2009-2010, Dell Inc.
# Copyright (C) 2011-2011, Canonical Ltd.
#
# Author:
#  - Mario Limonciello <Mario_Limonciello@Dell.com>
#  - Hsin-Yi Chen <hychen@canonical.com>
#
# This is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free
# Software Foundation; either version 2 of the License, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this application; if not, write to the Free Software Foundation, Inc., 51
# Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
##################################################################################

import dbus.mainloop.glib
import subprocess
from gi.repository import GLib
import os
import tempfile
import glob
import sys
import logging
import datetime
import distro
import uuid

from ubunturecovery import disksmgr
from ubunturecovery import metaclass

##                ##
##Common Variables##
##                ##

# D-Bus well-known name and interface served by the recovery backend
DBUS_BUS_NAME = 'com.ubuntu.RecoveryMedia'
DBUS_INTERFACE_NAME = 'com.ubuntu.RecoveryMedia'


# Translation support (gettext domain and catalogue directory)
DOMAIN = 'ubuntu-recovery'
LOCALEDIR = '/usr/share/locale'

# UI file directory: the in-tree 'gtk' directory is used only when it exists
# AND the DEBUG environment variable is set; otherwise the installed copy.
UIDIR = ('gtk' if (os.path.isdir('gtk') and 'DEBUG' in os.environ)
         else '/usr/share/ubuntu')


# Supported burners and the arguments that precede the image path.
# Insertion order matters: the first burner found on $PATH wins.
DVD_BURNERS = {
    'brasero': ['-i'],
    'nautilus-cd-burner': ['--source-iso='],
}
USB_BURNERS = {
    'usb-creator': ['-n', '--iso'],
    'usb-creator-gtk': ['--iso'],
    'usb-creator-kde': ['-n', '--iso'],
}

##                ##
##Common Functions##
##                ##
def get_part_uuid(part):
    """Return the filesystem UUID of partition *part*, or None.

    Runs ``blkid <part> -p -o udev`` and returns the value of the exact
    ``ID_FS_UUID`` key.  Related keys that share the prefix (for example
    ``ID_FS_UUID_ENC`` or ``ID_FS_UUID_SUB``) are ignored.  Returns None when
    blkid reports no such key.

    Raises RuntimeError (from fetch_output) if blkid exits non-zero.
    """
    output = fetch_output(['blkid', part, "-p", "-o", "udev"])
    for line in output.split('\n'):
        # partition() keeps any '=' that appears inside the value itself
        key, sep, value = line.partition('=')
        if sep and key == 'ID_FS_UUID':
            return value
    return None

def get_memsize():
    """Return the amount of system memory, rounded to whole GiB.

    Sums the size of every "System RAM" range listed under
    /sys/firmware/memmap.  If that tree is absent or yields nothing, falls
    back to the MemTotal line of /proc/meminfo.  Both sources are brought
    to KiB before the final KiB -> GiB conversion.
    """
    mem = 0
    memmap_dir = '/sys/firmware/memmap'
    if os.path.exists(memmap_dir):
        for root, _dirs, _files in os.walk(memmap_dir, topdown=False):
            type_path = os.path.join(root, 'type')
            if not os.path.exists(type_path):
                continue
            with open(type_path) as rfd:
                entry_type = rfd.readline().strip('\n')
            if entry_type != "System RAM":
                continue
            # base 0 lets int() honour a "0x" prefix; the range is inclusive
            with open(os.path.join(root, 'start')) as rfd:
                start = int(rfd.readline().strip('\n'), 0)
            with open(os.path.join(root, 'end')) as rfd:
                end = int(rfd.readline().strip('\n'), 0)
            mem += (end - start + 1)
        mem = float(mem / 1024)  # bytes -> KiB, matching MemTotal's unit
    if mem == 0:
        with open('/proc/meminfo', 'r') as rfd:
            for line in rfd:
                if line.startswith('MemTotal'):
                    mem = float(line.split()[1].strip())
                    break
    return round(mem / (1024 * 1024))  # KiB -> GiB

def black_tree(action, blacklist, src, dst='', base=None):
    """Recursively apply ACTION ("copy" or "size") to the files under src,
       skipping every file whose path relative to base is matched by the
       blacklist pattern (anything with a .search() method)."""
    return _tree(action, blacklist, src, dst, base, white=False)

def white_tree(action, whitelist, src, dst='', base=None):
    """Recursively apply ACTION ("copy" or "size") to the files under src,
       keeping only those whose path relative to base is matched by the
       whitelist pattern (anything with a .search() method)."""
    return _tree(action, whitelist, src, dst, base, white=True)

def _tree(action, list, src, dst, base, white):
    """Helper function for tree calls"""
    from distutils.file_util import copy_file

    if base is None:
        base = src
        if not base.endswith('/'):
            base += '/'

    names = os.listdir(src)

    if action == "copy":
        outputs = []
    elif action == "size":
        outputs = 0

    for n in names:
        src_name = os.path.join(src, n)
        dst_name = os.path.join(dst, n)
        end = src_name.split(base)[1]

        #don't copy symlinks or hardlinks, vfat seems to hate them
        if os.path.islink(src_name):
            continue

        #recurse till we find FILES
        elif os.path.isdir(src_name):
            if action == "copy":
                outputs.extend(
                    _tree(action, list, src_name, dst_name, base, white))
            elif action == "size":
                #add the directory we're in
                outputs += os.path.getsize(src_name)
                #add the files in that directory
                outputs += _tree(action, list, src_name, dst_name, base, white)

        #only copy the file if it matches the list / color
        elif (white and list.search(end)) or not (white or list.search(end)):
            if action == "copy":
                if not os.path.isdir(dst):
                    os.makedirs(dst)
                copy_file(src_name, dst_name, preserve_mode=1,
                          preserve_times=1, update=1, dry_run=0)
                outputs.append(dst_name)

            elif action == "size":
                outputs += os.path.getsize(src_name)

    return outputs

#@FIXME: remove me
def check_vendor():
    """Report whether the app is running on supported Ubuntu hardware.

    This is a stub: the answer is unconditionally True."""
    is_supported = True
    return is_supported

def get_pkgversion(package='ubuntu-recovery'):
    """Look up the installed version of *package* through python-apt.

    Returns the installed version string, or None if the package is known
    but not installed.  Any failure (python-apt missing, unknown package,
    cache errors) is printed to stderr and reported as "unknown".
    """
    try:
        import apt.cache
        pkg = apt.cache.Cache()[package]
        if pkg.is_installed:
            return pkg.installed.version
    except Exception as err:
        print("Error checking %s version: %s" % (package, err),
              file=sys.stderr)
        return "unknown"
    return None

def process_conf_file(original, new, uuid, rp_number, ako='', recovery_text=''):
    """Replaces all instances of a partition, OS, and extra in a conf type file
       Generally used for things that need to touch grub.

       original      -- template file to read (UTF-8)
       new           -- output file to write (UTF-8); its directory is
                        created when missing
       uuid          -- text substituted for #UUID#
       rp_number     -- text substituted for #PARTITION#; on releases
                        >= 10.10 it is prefixed with PlatInfo().disk_layout
       ako           -- extra kernel options for #EXTRA#; options found after
                        '--' on the running kernel command line are appended
                        unless ako already sets the same key (BOOT_IMAGE is
                        never propagated)
       recovery_text -- text substituted for #RECOVERY_TEXT#
       #OS# is replaced by "<ID> <RELEASE>" from lsb_release.
    """
    new_dir = os.path.dirname(new)
    # dirname is '' for a bare filename, and makedirs('') would raise
    if new_dir and not os.path.isdir(new_dir):
        os.makedirs(new_dir)
    import lsb_release
    release = lsb_release.get_distro_information()

    extra_cmdline = ako
    if extra_cmdline:
        #remove any duplicate entries
        ako_keys = [item.split('=')[0].strip() for item in extra_cmdline.split(' ')]
        for var in find_extra_kernel_options().split(' '):
            found = any(left and left in var for left in ako_keys)
            #propagate anything but BOOT_IMAGE (it gets added from isolinux)
            if not found and 'BOOT_IMAGE' not in var:
                extra_cmdline += ' ' + var
    else:
        extra_cmdline = find_extra_kernel_options()

    #starting with 10.10, we replace the whole drive string (/dev/sdX,msdosY)
    #earlier releases are hardcoded to (hd0,Y)
    if float(release["RELEASE"]) >= 10.10:
        rp_number = PlatInfo().disk_layout + rp_number

    # applied in this order on every line, as before
    substitutions = (
        ("#RECOVERY_TEXT#", recovery_text),
        ("#UUID#", uuid),
        ("#PARTITION#", rp_number),
        ("#OS#", "%s %s" % (release["ID"], release["RELEASE"])),
        ("#EXTRA#", "%s" % extra_cmdline.strip()),
    )
    with open(original, "r", encoding="utf-8") as base, \
         open(new, 'w', encoding='utf-8') as output:
        for line in base:
            for token, value in substitutions:
                if token in line:
                    line = line.replace(token, value)
            output.write(line)

def fetch_output(cmd, data=None):
    '''Run *cmd* and return its standard output as text.

    cmd  -- argument list passed to subprocess.Popen (no shell involved)
    data -- optional str (encoded to bytes) or bytes written to stdin

    Output is decoded as UTF-8 with undecodable bytes replaced, so a stray
    non-UTF-8 byte cannot turn a result (or the failure report below) into
    a UnicodeDecodeError.

    Raises RuntimeError, after logging the message to syslog, when the
    command exits with a non-zero status.
    '''
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE,
                            stdin=subprocess.PIPE)
    if isinstance(data, str):
        data = data.encode()
    # communicate() waits for the child, so returncode is set afterwards
    (out, err) = proc.communicate(data)
    out = out.decode('utf-8', errors='replace')
    err = err.decode('utf-8', errors='replace')
    if proc.returncode != 0:
        error = "Command %s failed with stdout/stderr: %s\n%s" % (cmd, out, err)
        import syslog
        syslog.syslog(error)
        raise RuntimeError(error)
    return out

def find_supported_ui():
    """Return {session id: display name} for the desktop sessions whose
    .desktop file is present in /usr/share/xsessions/."""
    sessions = (
        ('gnome', 'gnome.desktop', 'Unity (3D)'),
        ('unity-2d', 'unity-2d.desktop', 'Unity (2D)'),
        ('gnome-classic', 'gnome-classic.desktop', 'GNOME (Classic 3D)'),
        ('gnome-2d', 'gnome-2d.desktop', 'GNOME (Classic 2D)'),
    )
    return {
        key: label
        for key, desktop_file, label in sessions
        if os.path.exists(os.path.join('/usr/share/xsessions/', desktop_file))
    }

def find_extra_kernel_options(cmdline_path='/proc/cmdline'):
    """Return the kernel command-line arguments that follow a standalone '--'.

    Only a whitespace-separated '--' token counts as the separator, so an
    argument that merely contains '--' (e.g. a device-mapper path such as
    root=/dev/mapper/ubuntu--vg-root) does not split the line.  Everything
    after the first separator is returned, joined by single spaces; ''
    when there is no separator.

    cmdline_path -- file holding the command line; defaults to the running
                    kernel's /proc/cmdline.
    """
    with open(cmdline_path, 'r') as cmdline:
        tokens = cmdline.readline().split()
    if '--' not in tokens:
        return ''
    return ' '.join(tokens[tokens.index('--') + 1:])

def find_partition(label):
    """Look up the partition carrying *label* via disksmgr.

    Returns the matching device's ``devicefile``, or None when
    disksmgr.find_partition returns a falsy result."""
    match = disksmgr.find_partition(label)
    return match.devicefile if match else None

def find_burners():
    """Checks for what utilities are available to burn with.

    Returns a ``(dvd, usb)`` tuple.  Each element is ``[path] + args`` for
    the first program in DVD_BURNERS / USB_BURNERS found on $PATH, or None.
    When a DVD burning program exists, the system bus is queried (UDisks
    first, DeviceKit-Disks as a fallback).  *dvd* is reset to None if the
    answering service lists no drive whose DriveMediaCompatibility mentions
    'optical_dvd_r'.  If no service can be queried (including when the
    system bus is unavailable), *dvd* is left as found.
    """
    import shutil

    def find_command(burners):
        """Return [path] + args for the first burner in *burners* on $PATH."""
        for program, args in burners.items():
            path = shutil.which(program)
            if path is not None:
                return [path] + args
        return None

    def has_dvdr_drive(bus, service, object_path, interface, device_interface):
        """Ask *service* whether any enumerated drive supports DVD-R media.

        Raises dbus.DBusException when the service cannot be queried.
        """
        manager = dbus.Interface(bus.get_object(service, object_path), interface)
        for device in manager.EnumerateDevices():
            props = dbus.Interface(bus.get_object(service, device),
                                   'org.freedesktop.DBus.Properties')
            supported_media = props.Get(device_interface, 'DriveMediaCompatibility')
            if any('optical_dvd_r' in item for item in supported_media):
                return True
        return False

    dvd = find_command(DVD_BURNERS)
    usb = find_command(USB_BURNERS)

    #Only with an app for DVD burning is there hardware worth checking
    if not dvd:
        return (dvd, usb)

    try:
        bus = dbus.SystemBus()
    except dbus.DBusException as msg:
        # neither probe can run without a bus; keep the detected burner
        print("%s, UDisks Failed burner parse" % str(msg))
        return (dvd, usb)

    # (service, manager object path, manager interface, device interface,
    #  label used in the failure message) -- tried in order
    probes = (
        ('org.freedesktop.UDisks', '/org/freedesktop/UDisks',
         'org.freedesktop.UDisks', 'org.freedesktop.UDisks.Device',
         'UDisks'),
        ('org.freedesktop.DeviceKit.Disks', '/org/freedesktop/DeviceKit/Disks',
         'org.freedesktop.DeviceKit.Disks', 'org.freedesktop.DeviceKit.Disks.Device',
         'device kit'),
    )
    for service, object_path, interface, device_interface, label in probes:
        try:
            if not has_dvdr_drive(bus, service, object_path, interface,
                                  device_interface):
                dvd = None
            return (dvd, usb)
        except dbus.DBusException as msg:
            print("%s, %s Failed burner parse" % (str(msg), label))

    return (dvd, usb)

def match_system_device(bus, vendor, device):
    '''Attempts to match the vendor and device combination  on the specified bus
       Allows the following formats:
       base 16 int (eg 0x1234)
       base 16 int in a str (eg '0x1234')

       Only the 'usb' and 'pci' buses are supported; any other bus returns
       False.  Returns True when some device below /sys/bus/<bus>/devices
       reports the requested vendor and device ids.
    '''
    def recursive_check_ids(directory, cvendor, cdevice, depth=1):
        """Recurses into a directory to check all files in that directory

        Reads 'idVendor'/'vendor' and 'idProduct'/'device' files while
        walking and compares them, as hex, against cvendor/cdevice.
        Directories with no regular files are descended into explicitly,
        up to *depth* extra levels (presumably because sysfs device entries
        are symlinks, which os.walk does not follow -- TODO confirm).
        """
        vendor = device = ''
        for root, dirs, files in os.walk(directory, topdown=True):
            for fname in files:
                if not vendor and (fname == 'idVendor' or fname == 'vendor'):
                    with open(os.path.join(root, fname), 'r') as filehandle:
                        vendor = filehandle.readline().strip('\n')
                    # ignore longer values that are not 0x-prefixed hex ids
                    if len(vendor) > 4 and '0x' not in vendor:
                        vendor = ''
                elif not device and (fname == 'idProduct' or fname == 'device'):
                    with open(os.path.join(root, fname), 'r') as filehandle:
                        device = filehandle.readline().strip('\n')
                    if len(device) > 4 and '0x' not in device:
                        device = ''
            if vendor and device:
                if ( int(vendor, 16) == int(cvendor)) and \
                   ( int(device, 16) == int(cdevice)) :
                    return True
                else:
                    #reset devices so they aren't checked multiple times needlessly
                    vendor = device = ''
            if not files:
                if depth > 0:
                    for directory in [os.path.join(root, d) for d in dirs]:
                        if recursive_check_ids(directory, cvendor, cdevice, depth-1):
                            return True
        return False

    if bus != "usb" and bus != "pci":
        return False

    # isinstance (not type() ==) so str subclasses such as dbus.String are
    # converted too; otherwise int('0x...') fails in the walker
    if isinstance(vendor, str) and '0x' in vendor:
        vendor = int(vendor, 16)
    if isinstance(device, str) and '0x' in device:
        device = int(device, 16)

    return recursive_check_ids('/sys/bus/%s/devices' % bus, vendor, device)

def walk_cleanup(directory):
    """Delete *directory* together with everything inside it.

    A missing *directory* is silently ignored.  Symlinks found inside the
    tree (including broken ones) are unlinked, never followed."""
    if not os.path.exists(directory):
        return
    # bottom-up, so every real sub-directory is already empty when reached
    for parent, subdirs, filenames in os.walk(directory, topdown=False):
        for filename in filenames:
            os.remove(os.path.join(parent, filename))
        for subdir in subdirs:
            path = os.path.join(parent, subdir)
            if os.path.isdir(path) and not os.path.islink(path):
                os.rmdir(path)
            else:
                os.remove(path)
    os.rmdir(directory)

# COMPRESS= value from the unpacked initramfs.conf -> command used to
# recompress the main initramfs segment (reads stdin, writes stdout).
_INITRAMFS_COMPRESSORS = {
    'gzip': ['gzip', '-n'],
    'lzma': ['xz', '--check=crc32'],
    'xz': ['xz', '--check=crc32'],
    'lz4': ['lz4', '-9', '-l'],
    'zstd': ['zstd', '-q', '-1', '-T0'],
}

def _detect_initramfs_compression(initramfs_root):
    """Return the last COMPRESS= value in <root>/conf/initramfs.conf ('' if none)."""
    compression = ''
    conf = os.path.join(initramfs_root, 'conf', 'initramfs.conf')
    with open(conf, 'r') as rfd:
        for line in rfd:
            if line.startswith('COMPRESS='):
                components = line.split('=')
                if len(components) > 1:
                    # tolerate shell-style quoting, e.g. COMPRESS="zstd"
                    compression = components[1].strip().strip('"\'')
    return compression

def _append_cpio_segment(root, initrd_fd, compress_command=None):
    """Archive *root* as a newc cpio (optionally compressed) onto initrd_fd."""
    finder = subprocess.Popen(['find'], cwd=root, stdout=subprocess.PIPE)
    archiver = subprocess.Popen(['cpio', '--quiet', '-o', '-H', 'newc'],
                                cwd=root, stdin=finder.stdout,
                                stdout=subprocess.PIPE)
    # the child owns the read end now; closing ours lets EOF/SIGPIPE propagate
    finder.stdout.close()
    if compress_command:
        compressor = subprocess.Popen(compress_command,
                                      stdin=archiver.stdout,
                                      stdout=subprocess.PIPE)
        archiver.stdout.close()
        data = compressor.communicate()[0]
        procs = (finder, archiver, compressor)
    else:
        data = archiver.communicate()[0]
        procs = (finder, archiver)
    for proc in procs:
        if proc.wait() != 0:
            logging.warning("create_new_uuid: %s exited with status %s while packing %s" %
                            (proc.args, proc.returncode, root))
    initrd_fd.write(data)

def create_new_uuid(old_initrd_directory, old_casper_directory,
                    new_initrd_directory, new_casper_directory):
    """Regenerates the UUID contained in a casper initramfs.

    Unpacks the first initrd* in *old_initrd_directory* with unmkinitramfs,
    writes a fresh UUID to conf/uuid.conf inside it and to the casper-uuid*
    file in *new_casper_directory*, runs the ubuntu-bootstrap hook on it, and
    repacks everything as <new_initrd_directory>/initrd.  Multi-segment
    images are rebuilt segment by segment (early, early2 stored as-is, main
    compressed); a single-segment image is repacked compressed.

    Returns full path of the old initrd and casper files (for blocklisting)
    as a tuple (old_initrd_file, old_uuid_file).

    Raises dbus.DBusException if no initrd is found or the image's COMPRESS
    setting is not one of the supported methods.  The scratch directory is
    removed whether or not the rebuild succeeds.
    """
    #Detect the old initramfs stuff
    try:
        old_initrd_file = glob.glob('%s/initrd*' % old_initrd_directory)[0]
    except IndexError as msg:
        logging.warning("create_new_uuid: %s" % str(msg))
        raise dbus.DBusException("Missing initrd in image.")
    try:
        old_uuid_file = glob.glob('%s/casper-uuid*' % old_casper_directory)[0]
    except IndexError as msg:
        logging.warning("create_new_uuid: Old casper UUID not found, assuming 'casper-uuid': %s" % msg)
        old_uuid_file = '%s/casper-uuid' % old_casper_directory

    if not old_initrd_file or not old_uuid_file:
        raise dbus.DBusException("Unable to detect valid initrd.")

    logging.debug("create_new_uuid: old initrd %s, old uuid %s" %
                  (old_initrd_file, old_uuid_file))

    tmpdir = tempfile.mkdtemp()
    try:
        #Extract old initramfs with the new format
        chain0 = subprocess.Popen(["/usr/bin/unmkinitramfs", old_initrd_file, "."],
                                  stdout=subprocess.PIPE, cwd=tmpdir)
        chain0.communicate()

        # multi-segment images unpack into early*/ and main/; otherwise the
        # single archive lands directly in tmpdir
        initramfs_root = os.path.join(tmpdir, 'main')
        if not os.path.exists(initramfs_root):
            initramfs_root = tmpdir

        #Generate new UUID
        new_uuid_file = os.path.join(new_casper_directory,
                                     os.path.basename(old_uuid_file))
        logging.debug("create_new_uuid: new uuid file: %s" % new_uuid_file)
        new_uuid = str(uuid.uuid4())
        logging.debug("create_new_uuid: new UUID: %s" % new_uuid)
        for item in [new_uuid_file, os.path.join(initramfs_root, 'conf', 'uuid.conf')]:
            with open(item, "w") as uuid_fd:
                uuid_fd.write("%s\n" % new_uuid)

        #Add bootstrap to initrd
        chain0 = subprocess.Popen(['/usr/share/ubuntu/casper/hooks/ubuntu-bootstrap'],
                                  env={'DESTDIR': initramfs_root, 'INJECT': '1'})
        chain0.communicate()

        #Detect compression from the unpacked image's own configuration
        new_compression = _detect_initramfs_compression(initramfs_root)
        compress_command = _INITRAMFS_COMPRESSORS.get(new_compression)
        logging.debug("create_new_uuid: compression detected: %s" % new_compression)
        logging.debug("create_new_uuid: compression command: %s" % compress_command)
        if compress_command is None:
            raise dbus.DBusException("Unsupported initramfs compression: '%s'" %
                                     new_compression)

        #Generate new initramfs
        new_initrd_file = os.path.join(new_initrd_directory, 'initrd')
        logging.debug("create_new_uuid: new initrd file: %s" % new_initrd_file)

        # segments are written into a fresh file, so drop any old target first
        if os.path.lexists(new_initrd_file):
            print("Remove old initrd file: %s" % new_initrd_file)
            os.remove(new_initrd_file)

        if initramfs_root == tmpdir:
            segments = [(tmpdir, compress_command)]
        else:
            # make the early and late sections separately; only main is compressed
            segments = [(os.path.join(tmpdir, component),
                         compress_command if component == 'main' else None)
                        for component in ('early', 'early2', 'main')]
        with open(new_initrd_file, 'wb') as initrd_fd:
            for root, command in segments:
                if os.path.exists(root):
                    _append_cpio_segment(root, initrd_fd, command)
    finally:
        walk_cleanup(tmpdir)

    return (old_initrd_file, old_uuid_file)

def dbus_sync_call_signal_wrapper(dbus_iface, func, handler_map, *args, **kwargs):
    '''Run a D-BUS method call while receiving signals.

    This function is an Ugly Hack™, since a normal synchronous dbus_iface.fn()
    call does not cause signals to be received until the method returns. Thus
    it calls func asynchronously and sets up a temporary main loop to receive
    signals and call their handlers; these are assigned in handler_map (signal
    name → signal handler).

    The signal receivers are disconnected again once the call finishes, so
    repeated calls do not stack duplicate handlers.  Objects without
    connect_to_signal are called directly.  Returns the tuple of reply
    arguments, and re-raises whatever the error handler received.
    '''
    if not hasattr(dbus_iface, 'connect_to_signal'):
        # not a D-BUS object
        return getattr(dbus_iface, func)(*args, **kwargs)

    loop = GLib.MainLoop()
    reply_result = None
    reply_error = None

    def _h_reply(*reply_args, **_reply_kwargs):
        """protected method to store the reply and stop the loop"""
        nonlocal reply_result
        reply_result = reply_args
        loop.quit()

    def _h_error(exception=None):
        """protected method to store the error and stop the loop"""
        nonlocal reply_error
        reply_error = exception
        loop.quit()

    kwargs['reply_handler'] = _h_reply
    kwargs['error_handler'] = _h_error
    kwargs['timeout'] = 86400
    matches = [dbus_iface.connect_to_signal(signame, sighandler)
               for signame, sighandler in handler_map.items()]
    try:
        dbus_iface.get_dbus_method(func)(*args, **kwargs)
        loop.run()
    finally:
        for match in matches:
            remove = getattr(match, 'remove', None)
            if remove is not None:
                remove()
    if reply_error is not None:
        raise reply_error
    return reply_result


##                ##
## Common Classes ##
##                ##

class RestoreFailed(dbus.DBusException):
    """Raised when restoring the system fails, whatever the cause.

    _dbus_error_name is the D-Bus error name dbus-python reports for it."""
    _dbus_error_name = 'com.ubuntu.RecoveryMedia.RestoreFailedException'

class CreateFailed(dbus.DBusException):
    """Raised when building recovery media fails, whatever the cause.

    _dbus_error_name is the D-Bus error name dbus-python reports for it."""
    _dbus_error_name = 'com.ubuntu.RecoveryMedia.CreateFailedException'

class PermissionDeniedByPolicy(dbus.DBusException):
    """Raised when PolicyKit refuses the caller the requested action.

    _dbus_error_name is the D-Bus error name dbus-python reports for it."""
    _dbus_error_name = 'com.ubuntu.RecoveryMedia.PermissionDeniedByPolicy'

class BackendCrashError(SystemError):
    """Raised when the backend process crashes."""

class Seed(object):
    """In-memory view of a debconf preseed file at *path*.

    Keys are question names containing a '/'; values are the answer text
    with the owner and type fields dropped.  The file is parsed on creation.
    """
    def __init__(self, path):
        self.path = path
        self._keys = {}
        self.parse()

    def keys(self):
        """Return the live key -> value dictionary (not a copy)."""
        return self._keys

    def parse(self):
        """Parses a preseed file into this object's keys.

        Each non-comment line is read as ``<owner> <question> <type> <value...>``.
        The owner (ubiquity or d-i generally) and the type are discarded, and
        the value words are re-joined with single spaces.  Lines with fewer
        than two fields, and questions without a '/', are skipped.  Returns {}
        when the file does not exist.
        """
        if not os.path.exists(self.path):
            return {}
        with open(self.path, 'r') as rfd:
            for line in rfd:
                fields = line.split()
                # blank, comment, or malformed (owner-only) lines carry no key
                if len(fields) < 2 or fields[0].startswith('#'):
                    continue
                key = fields[1]
                if '/' in key:
                    self._keys[key] = " ".join(fields[3:])

    def set(self, key, val):
        """Set *key* to *val*."""
        self._keys[key] = val

    def setkeys(self, keys):
        """Set every key/value pair from the mapping *keys*."""
        for k, v in keys.items():
            self.set(k, v)

    def get(self, key):
        """Return the value for *key*, or None if it is not set."""
        return self._keys.get(key)

    def save(self):
        """Writes out a preseed file with all current keys.

        Every key is written as a ``ubiquity`` question.  Values equal to
        'true' or 'false' are typed boolean, everything else string.  The
        file at self.path is overwritten, headed by a comment with today's
        date.
        """
        with open(self.path, 'w') as wfd:
            wfd.write("# Ubuntu Recovery configuration preseed\n")
            wfd.write("# Last updated on %s\n" % datetime.date.today())
            wfd.write("\n")
            for item, value in self._keys.items():
                value_type = 'boolean' if value in ('true', 'false') else 'string'
                wfd.write(" ubiquity %s %s %s\n" % (item, value_type, value))

class Values(object, metaclass=metaclass.DataObjectType):
    """Generic data object; its behaviour comes entirely from the
    project's metaclass.DataObjectType (defined elsewhere)."""

class PlatInfo(object, metaclass=metaclass.Singleton):
    """Platform Informations Data Object

    Attributes set at construction:
      memsize     -- memory size from get_memsize() (GB)
      dist        -- distro.linux_distribution(full_distribution_name=False)
      isefi       -- True if /proc/efi or /sys/firmware/efi is a directory
      disk_layout -- 'gpt' on EFI systems, otherwise 'msdos'
    """

    def __init__(self):
        self.memsize = get_memsize() #GB
        self.dist = distro.linux_distribution(full_distribution_name=False)
        # check efi
        self.isefi = (os.path.isdir('/proc/efi') or
                      os.path.isdir('/sys/firmware/efi'))
        self.disk_layout = 'gpt' if self.isefi else 'msdos'

        # filled lazily from /proc/cmdline by the cmdline property
        self._cmdline = []

    def get_cmdline_opt(self, optname):
        """Return the value of the first ``optname=value`` kernel argument.

        Only arguments that START with ``optname=`` match, so asking for
        'root' does not pick up 'fakeroot=...'.  Everything after the first
        '=' is returned unchanged.  Returns None when no such argument is
        present.
        """
        prefix = optname + '='
        for item in self.cmdline:
            if item.startswith(prefix):
                return item[len(prefix):]
        return None

    @property
    def cmdline(self):
        """Whitespace-split kernel command line, read once and cached."""
        if not self._cmdline:
            with open ('/proc/cmdline', 'r') as rfd:
                self._cmdline = rfd.readline().split()
        return self._cmdline
