#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (C) 2019-2024 fjord-technologies
# SPDX-License-Identifier: GPL-3.0-or-later
"""
fd-svc-deploy
"""

import argparse
import fnmatch
import json
import locale
import logging
import os
import re
import subprocess
import sys

from collections import OrderedDict
from packaging.version import Version, parse, InvalidVersion

import redis

from ansible.inventory.manager import split_host_pattern
from dialog import Dialog
from sonicprobe import helpers

import boto3
from botocore.exceptions import ClientError

SYSLOG_NAME = "fd-svc-deploy"
LOG         = logging.getLogger(SYSLOG_NAME)

locale.setlocale(locale.LC_ALL, '')


APTLY_REPO_DISTRIBS = ('stretch',
                       'buster',
                       'bullseye',
                       'bookworm',
                       'trixie')
APTLY_REPO_VERSION  = {'branch': 'devel-%s-fd',
                       'tag': '%s-fd'}
APTLY_REPO_SEARCH   = ('sudo',
                       '-u',
                       'aptly',
                       '/usr/bin/aptly',
                       '-config=/etc/aptly-private.conf',
                       'repo',
                       'search')
ANSIPLAY            = ('fd-ansible',
                       'play')
ANSIBLE_HOSTS_LIST  = ('ansible', 'FjordSVC', '--list-hosts')
ANSIBLE_CHECK_GROUP = ('ansible-inventory', '--graph')
ANSIBLE_FETCH_SVCS  = ('ansible', '-m', 'debug', '-a', 'var=fjord_svc')

ANSIBLE_PB_PATH     = "/etc/ansible/playbooks"
ANSIBLE_PB_SUFFIX   = "%s-install.yml"

PKG_PREFIX          = "fjord-svc-"

EXCLUDE_PKG         = re.compile(r'(_source$|_0\.0\.0)').search
PKG_PARSE_VERSION = {
    'branch': re.compile(r'^[a-z0-9\-]+_([a-zA-Z0-9\.\+\-]+(?:\.[a-zA-Z0-9\+\-]+)*(?:\+[^+]+)?)(?=\+[^+]+$)').match,
    'tag':    re.compile(r'^[a-z0-9\-]+_([a-zA-Z0-9\.\+\-]+(?:\.[a-zA-Z0-9\+\-]+)*(?:\+[^+]+)?)(?=\+[^+]+$)').match,
 }
PARSE_HOSTGROUPS    = re.compile(r'^([a-z0-9\-]+)-(?:[0-9]+|[0-9]+[a-z]+)(\.[a-z0-9\.\-]+)$').match
NB_MAX_VERSIONS     = 25
DOCKER_SERVICES = {
    "cookies-sync",
    "es-bulk",
    "reconciliation",
    "recrypt",
    "segment-query-generator",
    "user",
    "user-seed",
    "user-refresh",
    "fuse-stats-updater"
}

DOCKER_SERVICES_FJORD = {"runner"}

def argv_parse_check(argv = None):
    """
    Parse (and check a little) command line parameters.

    :param argv: argument list to parse, without the program name; None
                 (the default) makes argparse read ``sys.argv[1:]``.
    :returns: the argparse namespace; ``loglevel`` is converted to the
              matching ``logging`` level constant.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-C",
                        action    = 'store_true',
                        dest      = 'check',
                        default   = False,
                        help      = "Don't make any changes; instead, try to predict some of the changes that may occur")
    parser.add_argument("-d",
                        dest      = 'distrib',
                        default   = 'stretch',
                        choices   = APTLY_REPO_DISTRIBS,
                        help      = "Available distributions: " + ", ".join(APTLY_REPO_DISTRIBS) + ", instead of %(default)s")
    parser.add_argument("-l",
                        action    = 'append',
                        dest      = 'subset',
                        default   = [],
                        help      = "Further limit selected hosts to an additional pattern")
    parser.add_argument("--loglevel",
                        dest      = 'loglevel',
                        default   = 'info', # converted to a logging level constant below
                        choices   = ('critical', 'error', 'warning', 'info', 'debug'),
                        help      = ("Emit traces with LOGLEVEL details, must be one of:\t"
                                     "critical, error, warning, info, debug"))
    parser.add_argument("-m",
                        dest      = 'mode',
                        default   = 'group',
                        choices   = ('host', 'group', 'refresh-cache'),
                        help      = "Mode: host, group or refresh-cache, instead of %(default)s")
    parser.add_argument("--nb-max-versions",
                        dest      = 'nb_max_versions',
                        type      = int,
                        default   = NB_MAX_VERSIONS,
                        help      = "Maximum number of versions to display, instead of %(default)s")
    parser.add_argument("--no-maint",
                        action    = 'store_true',
                        dest      = 'nomaint',
                        default   = False,
                        help      = "Disable maintenance mode")
    parser.add_argument("--serial",
                        dest      = 'serial',
                        type      = int,
                        # Typed default: argparse would otherwise have to
                        # convert the string default through `type`.
                        default   = 1,
                        help      = "How many hosts to deploy at a time, instead of %(default)s")
    parser.add_argument("--redis",
                        dest      = 'redis',
                        default   = 'redis://127.0.0.1:6379/7',
                        help      = "Redis default server URL, instead of %(default)s")
    parser.add_argument("--ref",
                        dest      = 'reference',
                        default   = 'tag',
                        choices   = ('branch', 'tag'),
                        help      = "Reference: branch or tag, instead of %(default)s")
    parser.add_argument("-t",
                        action    = 'store_true',
                        dest      = 'test',
                        default   = False,
                        help      = "Run motd role for testing purpose")
    parser.add_argument("--tag-pkg",
                        action    = 'append',
                        dest      = 'tag_pkg',
                        default   = [],
                        help      = "Rewrite the ansible tag of a package, as PKG:TAG (repeatable)")
    parser.add_argument("-v",
                        action    = 'count',
                        dest      = 'verbose',
                        default   = 0,
                        help      = "verbose mode (-vvv for more, -vvvv to enable connection debugging)")

    args          = parser.parse_args(argv)
    args.loglevel = getattr(logging, args.loglevel.upper(), logging.INFO)

    return args


class FdSVCDeployExit(SystemExit):
    """
    Controlled termination of the tool, carrying the process exit code.
    main() clears the terminal when that code is positive.
    """

class FdSVCDeployExitNoClear(FdSVCDeployExit):
    """
    Controlled termination that leaves the terminal untouched, so that
    messages logged just before stay visible.
    """

class FdSVCDeployHost(str):
    """Name of a single ansible host picked as deployment target."""

class FdSVCDeployHostGroup(str):
    """Name of an ansible host group picked as deployment target."""


class FdSVCDeploy(object):
    """
    Interactive (dialog based) deployment of fjord-svc packages.

    Workflow: select hosts (``-m host``) or host groups (``-m group``),
    select the Fjord services of every target, pick a version for each
    service (aptly repository for Debian packages, AWS ECR for the services
    of DOCKER_SERVICES), confirm the summary, then run the matching ansible
    playbook for every target.
    """

    def __init__(self, options):
        """
        :param options: parsed command line options (see argv_parse_check()).
        """
        self.options   = options
        self._subset   = []  # match callables built from the -l patterns
        self._slt_pkgs = {}  # packages picked for the previous target, pre-checked next time
        self._slt_vers = {}  # aptly version picked per package, pre-selected next time

        if self.options.subset:
            for subset in split_host_pattern(self.options.subset):
                self._subset.append(self._subset_pattern(subset))

        self.dialog  = Dialog(dialog='dialog', autowidgetsize = True)
        self.dialog.set_background_title("fjord-svc packages deployment")

    @staticmethod
    def _subset_pattern(pattern_str):
        """
        Compile one ``-l`` pattern into a match callable.

        A pattern starting with ``~`` is a regular expression (the ``~`` is
        stripped); anything else is a shell-style wildcard.
        """
        if pattern_str.startswith('~'):
            return re.compile(pattern_str[1:]).match

        return re.compile(fnmatch.translate(pattern_str)).match

    def _match_host(self, host):
        """
        Return True when ``host`` (a host or group name) passes at least one
        ``-l`` pattern, or when no pattern was given.
        """
        if not self._subset:
            return True

        return any(pattern(host) for pattern in self._subset)

    @staticmethod
    def _host_group(host):
        """
        Return the group name used to find the playbook of a target.

        Host names (FdSVCDeployHost) are reduced with PARSE_HOSTGROUPS; group
        names are returned unchanged.

        :raises FdSVCDeployExitNoClear: when a host name does not match
                                        PARSE_HOSTGROUPS.
        """
        if not isinstance(host, FdSVCDeployHost):
            return host

        m = PARSE_HOSTGROUPS(host)
        if not m:
            LOG.error("unable to find host group for host: %r", host)
            raise FdSVCDeployExitNoClear(2)

        return ''.join(m.groups())

    @staticmethod
    def _find_playbook_filepath(grp):
        """
        Locate the ``<name>-install.yml`` playbook of a dotted group name.

        The first dot-separated component gives the playbook name. The other
        components, reversed, give a directory path below ANSIBLE_PB_PATH.
        That path is tried first, then its parents one level at a time;
        ANSIBLE_PB_PATH itself is not searched.

        :raises FdSVCDeployExitNoClear: when no playbook is found.
        """
        # Drop the trailing region suffix, if any.
        parts = grp.split('.')
        if len(parts) > 3:
            parts = parts[:-1]

        xfile = ANSIBLE_PB_SUFFIX % parts[0]
        dirs  = parts[:0:-1]  # remaining components, reversed

        while dirs:
            # isfile() is False as well when the directory does not exist.
            candidate = os.path.join(ANSIBLE_PB_PATH, os.path.sep.join(dirs), xfile)
            if os.path.isfile(candidate):
                return candidate

            dirs.pop()

        LOG.error("unable to find ansible playbook file: '%s'", xfile)
        raise FdSVCDeployExitNoClear(2)

    def _get_cache_groups(self):
        """
        Read the cached host groups from redis (best effort).

        :returns: the cached groups dict, or None when the cache is empty,
                  unreachable or does not hold a JSON object.
        """
        r = None

        try:
            r   = redis.from_url(self.options.redis)
            raw = r.get('fd-svc-deploy:groups')
            if not raw:
                return None

            groups = json.loads(raw)
        except Exception as e:
            # Best effort: a broken cache must never block a deployment.
            LOG.debug("unable to read groups cache: %s", e)
            return None
        finally:
            if r and r.connection_pool:
                r.connection_pool.disconnect()

        if not groups or not isinstance(groups, dict):
            return None

        return groups

    def _set_cache_groups(self, groups):
        """
        Store ``groups`` (a non-empty dict) in the redis cache (best effort).

        :returns: None.
        """
        if not groups or not isinstance(groups, dict):
            return None

        r = None

        try:
            r = redis.from_url(self.options.redis)
            r.set('fd-svc-deploy:groups', json.dumps(groups))
        except Exception as e:
            # Best effort, but say so: refresh-cache would otherwise report
            # nothing when redis is unreachable.
            LOG.warning("unable to write groups cache: %s", e)
        finally:
            if r and r.connection_pool:
                r.connection_pool.disconnect()

        return None

    def _groups_list(self, from_cache = True):
        """
        Build the checklist choices for host groups and refresh the cache.

        Base groups are derived from the FjordSVC host names with
        PARSE_HOSTGROUPS. Each base is expanded to its ``<base>.*`` inventory
        groups matching ``-l``. A base falls back to itself, when it is an
        inventory group, if no new sub-group was found for it.

        :param from_cache: kept for backward compatibility; the list is
                           always rebuilt from ansible (the cached value was
                           never used).
        :returns: list of ``(group, "", False)`` choices, or None when
                  ansible printed nothing.
        """
        hosts_out = subprocess.check_output(ANSIBLE_HOSTS_LIST)
        if not hosts_out:
            return None

        inv = subprocess.check_output(('ansible-inventory', '--list'))
        if not inv:
            return None

        inventory  = json.loads(inv)
        all_groups = sorted(inventory)

        base_groups = set()
        for raw in hosts_out.splitlines():
            line = raw.strip().decode('utf-8')
            if not line or line.startswith("hosts ("):
                continue

            m = PARSE_HOSTGROUPS(line)
            if m:
                base_groups.add(''.join(m.groups()))

        choices = []
        seen    = {}  # {group: 1}, also the redis cache payload

        for base in sorted(base_groups):
            prefix          = base + "."
            found_subgroups = False

            for grp in all_groups:
                if grp.startswith(prefix) and grp not in seen and self._match_host(grp):
                    seen[grp] = 1
                    choices.append((grp, "", False))
                    found_subgroups = True

            if (not found_subgroups
                    and base in inventory
                    and base not in seen
                    and self._match_host(base)):
                seen[base] = 1
                choices.append((base, "", False))

        self._set_cache_groups(seen)

        return choices

    def _ecr_list_tags(self, repository, exclude_arm=False):
        """
        List the deployable image tags of an ECR repository.

        Tags that do not parse as versions are ignored, except ``latest``
        which is offered first when present.

        :param repository: ECR repository name (e.g. ``services/user``).
        :param exclude_arm: skip tags containing "arm" (case-insensitive).
        :returns: list of ``(display, ecr_tag)`` pairs: ``latest`` first if
                  present, then at most ``nb_max_versions`` versions, newest
                  first.
        :raises FdSVCDeployExitNoClear: on ECR API errors (after logging, so
                                        the message stays on screen).
        """
        region = os.environ.get("AWS_DEFAULT_REGION", "eu-west-3")
        ecr = boto3.client("ecr", region_name=region)

        semver_tags = []
        has_latest = False

        try:
            paginator = ecr.get_paginator("describe_images")
            for page in paginator.paginate(repositoryName=repository):
                for image in page.get("imageDetails", []):
                    pushed_at = image.get("imagePushedAt")
                    for tag in image.get("imageTags", []):

                        if tag == "latest":
                            has_latest = True
                            continue

                        if exclude_arm and 'arm' in tag.lower():
                            continue

                        try:
                            parsed = parse(tag)
                            display = tag
                        except InvalidVersion:
                            try:
                                # ECR replaces '+' with '-' in tag names; reverse that for parsing
                                tag_for_parse = re.sub(r'-([a-zA-Z][a-zA-Z0-9]*)$', r'+\1', tag)
                                parsed = parse(tag_for_parse)
                                display = tag_for_parse
                            except InvalidVersion:
                                continue

                        semver_tags.append((display, tag, parsed, pushed_at))

        except ClientError as e:
            LOG.error("ECR error: %s", e)
            raise FdSVCDeployExitNoClear(2) from e

        semver_tags = sorted(semver_tags, key=lambda x: x[2], reverse=True)
        semver_tags = semver_tags[:self.options.nb_max_versions]

        out = []
        if has_latest:
            out.append(("latest", "latest"))
        out += [(t[0], t[1]) for t in semver_tags]
        return out

    def _tag_package(self, tag):
        """
        Return the ansible tag to run for package ``tag``.

        A ``--tag-pkg PKG:TAG`` option rewrites the tag of PKG. The first
        non-empty rewrite wins; empty ones are reported and ignored.
        """
        if not self.options.tag_pkg or not tag:
            return tag

        prefix = "%s:" % tag
        for tag_pkg in self.options.tag_pkg:
            if not tag_pkg.startswith(prefix):
                continue

            t = tag_pkg.split(':', 1)[1]
            if t:
                return t

            LOG.warning("invalid tag package: %s", tag_pkg)

        return tag

    @staticmethod
    def _parse_debug_output(output):
        """
        Extract the JSON payloads printed by ``ansible -m debug``.

        Every ``<host> | SUCCESS => {`` header starts a payload. Lines are
        accumulated until a line reading ``}`` closes a valid JSON object.
        Nested objects also close on a bare ``}`` line, so a failed parse
        there just means "keep reading". Other lines (warnings, failed
        hosts) are ignored.

        :param output: raw bytes printed by ansible.
        :returns: list of decoded payload dicts, in output order.
        """
        payloads = []
        chunk    = None

        for raw in output.splitlines():
            line = raw.strip().decode('utf-8')

            if chunk is None:
                if line.endswith(" SUCCESS => {"):
                    chunk = ["{"]
                continue

            chunk.append(line)
            if line != "}":
                continue

            try:
                payloads.append(json.loads("\n".join(chunk)))
            except ValueError:
                continue

            chunk = None

        return payloads

    def _aptly_versions(self, pkg):
        """
        Return the aptly versions of ``pkg`` for the selected ``--ref`` and
        ``-d``, newest first, at most ``nb_max_versions`` of them.

        Lines matching EXCLUDE_PKG and versions rejected by
        ``packaging.version.Version`` are skipped.
        """
        out = subprocess.check_output(APTLY_REPO_SEARCH
                                      + (APTLY_REPO_VERSION[self.options.reference] % self.options.distrib,
                                         "%s%s" % (PKG_PREFIX, pkg),))
        parse_version = PKG_PARSE_VERSION[self.options.reference]

        found = []
        for raw in out.splitlines():
            line = raw.strip().decode('utf-8')
            if EXCLUDE_PKG(line):
                continue

            m = parse_version(line)
            if not m:
                continue

            full_version = m.group(1)
            try:
                found.append((full_version, Version(full_version)))
            except InvalidVersion:
                pass

        found.sort(key = lambda item: item[1], reverse = True)

        # Same limit as the ECR listing (the former `i <= max` showed max + 1).
        return [version for version, _ in found[:self.options.nb_max_versions]]

    def _select_docker_version(self, host, pkg):
        """
        Ask for the ECR image tag of the Docker service ``pkg`` on ``host``.

        ARM tags are only offered when the target name contains "arm".

        :returns: the ECR tag to deploy, or None when the repository has no
                  usable tag (an error is logged).
        :raises FdSVCDeployExit: when the dialog is cancelled or escaped.
        """
        versions = self._ecr_list_tags(f"services/{pkg}",
                                       exclude_arm = 'arm' not in host.lower())
        if not versions:
            LOG.error("No Docker tags found for %s", pkg)
            return None

        # The menu shows the display form; map it back to the real ECR tag.
        ecr_tags = dict(versions)

        code, selected = self.dialog.menu(f"Select Docker version for {pkg}:",
                                          title   = host,
                                          choices = [(display, "") for display, _ in versions],
                                          width   = 80)
        if code != self.dialog.OK:
            raise FdSVCDeployExit(1)

        return ecr_tags.get(selected, selected)

    def _select_aptly_version(self, host, pkg):
        """
        Ask for the aptly version of ``pkg`` to deploy on ``host``.

        The version picked for the same package on a previous target is
        pre-selected.

        :returns: the selected version, or None when the repository has no
                  usable version (an error is logged).
        :raises FdSVCDeployExit: when the dialog is cancelled or escaped.
        """
        versions = self._aptly_versions(pkg)
        if not versions:
            # Skip instead of re-running the same search forever.
            LOG.error("unable to find package Fjord SVC: %r", pkg)
            return None

        choices = [(version, "") for version in versions]

        while True:
            code, version = self.dialog.menu("Select %s version:" % pkg,
                                             title        = host,
                                             choices      = choices,
                                             default_item = self._slt_vers.get(pkg, ''),
                                             width        = 80)
            if code != self.dialog.OK:
                raise FdSVCDeployExit(1)

            if version:
                self._slt_vers[pkg] = version
                return version

    def hosts(self):
        """
        Let the user pick target hosts among the FjordSVC hosts matching -l.

        :returns: list of FdSVCDeployHost.
        :raises FdSVCDeployExitNoClear: when no host is available.
        :raises FdSVCDeployExit: when the dialog is cancelled or escaped.
        """
        out = subprocess.check_output(ANSIBLE_HOSTS_LIST)

        choices = []
        for raw in out.splitlines():
            host = raw.strip().decode('utf-8')
            if host and not host.startswith("hosts (") and self._match_host(host):
                choices.append((host, "", False))

        if not choices:
            LOG.error("no host available for deployment")
            raise FdSVCDeployExitNoClear(2)

        while True:
            code, selected = self.dialog.checklist("Select hosts:",
                                                   choices = choices,
                                                   width   = 80)
            if code != self.dialog.OK:
                raise FdSVCDeployExit(1)

            if selected:
                return [FdSVCDeployHost(h) for h in selected]

    def groups(self):
        """
        Let the user pick target host groups (see _groups_list()).

        :returns: list of FdSVCDeployHostGroup.
        :raises FdSVCDeployExitNoClear: when no group is available.
        :raises FdSVCDeployExit: when the dialog is cancelled or escaped.
        """
        choices = self._groups_list()
        if not choices:
            LOG.error("no host group available for deployment")
            raise FdSVCDeployExitNoClear(2)

        while True:
            code, selected = self.dialog.checklist("Select groups:",
                                                   choices = choices,
                                                   width   = 80)
            if code != self.dialog.OK:
                raise FdSVCDeployExit(1)

            if selected:
                return [FdSVCDeployHostGroup(g) for g in selected]

    def pkgs(self, hosts):
        """
        Let the user pick, for each target, the Fjord services to deploy.

        The services offered come from the ``fjord_svc`` ansible variable of
        the target. Targets without services are skipped with an error. The
        previous selection is pre-checked for the next target.

        :param hosts: iterable of FdSVCDeployHost / FdSVCDeployHostGroup.
        :returns: OrderedDict ``{target: {pkg: ''}}``.
        :raises FdSVCDeployExit: when the dialog is cancelled or escaped.
        """
        r = OrderedDict()

        for host in hosts:
            variables = {}
            for payload in self._parse_debug_output(
                    subprocess.check_output(ANSIBLE_FETCH_SVCS + (host,))):
                variables = helpers.merge(variables, payload)

            svcs = variables.get('fjord_svc')
            if not svcs:
                # Skip the target: re-running the same ansible command
                # would not change the result.
                LOG.error("unable to find Fjord SVC for host: %r", host)
                continue

            choices = [(svc, "", self._slt_pkgs.get(svc) or False)
                       for svc in sorted(svcs)]

            while True:
                code, pkgs = self.dialog.checklist("Select packages:",
                                                   title   = host,
                                                   choices = choices,
                                                   width   = 80)
                if code != self.dialog.OK:
                    raise FdSVCDeployExit(1)

                if pkgs:
                    break

            r[host] = dict.fromkeys(pkgs, '')
            self._slt_pkgs = dict.fromkeys(pkgs, True)

        return r

    def pkgs_version(self, hosts):
        """
        Ask for the version of every selected package.

        :param hosts: mapping ``{target: {pkg: ...}}`` as returned by pkgs().
        :returns: OrderedDict ``{target: {pkg: version}}``; packages without
                  any available version are left out.
        :raises FdSVCDeployExit: when a dialog is cancelled or escaped.
        """
        r = OrderedDict()

        for host, pkgs in hosts.items():
            for pkg in pkgs:
                if pkg in DOCKER_SERVICES:
                    version = self._select_docker_version(host, pkg)
                else:
                    version = self._select_aptly_version(host, pkg)

                if version is None:
                    continue

                r.setdefault(host, OrderedDict())[pkg] = version

        return r

    def summary(self, tasks):
        """
        Show the planned deployment and ask for confirmation.

        :raises FdSVCDeployExit: unless the user answers yes.
        """
        txt = ["Are you sure to proceed?", ""]
        for host, pkgs in tasks.items():
            txt.append("%s:" % host)
            txt.extend("    %s = %s" % (pkg, ver) for pkg, ver in pkgs.items())

        code = self.dialog.yesno("\n".join(txt),
                                 title  = "Installation summary",
                                 width  = 80,
                                 height = 20)

        if code != self.dialog.OK:
            raise FdSVCDeployExit(1)

    def run_tasks(self, tasks):
        """
        Run ``fd-ansible play`` for every target of ``tasks``.

        Playbooks are resolved for all targets before the first run, so a
        missing playbook aborts before any host is changed.

        :param tasks: mapping ``{target: {pkg: version}}``.
        :raises FdSVCDeployExitNoClear: when a host group or playbook cannot
                                        be found.
        :raises subprocess.CalledProcessError: when a play fails.
        """
        self.dialog.clear()
        os.system('clear')

        env = os.environ.copy()
        env['ANSIBLE_ANY_ERRORS_FATAL'] = 'True'

        plan = [(host, pkgs, self._find_playbook_filepath(self._host_group(host)))
                for host, pkgs in tasks.items()]

        for host, pkgs, playbook in plan:
            ansiargs = ANSIPLAY + (playbook, '-l', host)

            if self.options.verbose:
                ansiargs += ('-' + self.options.verbose * 'v',)
            if self.options.check:
                ansiargs += ('-C',)

            tags = ['motd']

            if not self.options.test and not self.options.nomaint:
                tags.append('automaint')

            eargs = {'fd_ansiplay_serial': str(self.options.serial),
                     'FD_SVC_DEPLOY': True,
                     'FD_SVC_DEPLOY_MODE': self.options.mode,
                     'FD_SVC_DEPLOY_TEST': self.options.test,
                     'fjord_svc_version':  {}}

            if any(pkg in DOCKER_SERVICES for pkg in pkgs):
                eargs['fjord_svc_image_repo'] = \
                    "929885317002.dkr.ecr.eu-west-3.amazonaws.com/services"

            for pkg, version in pkgs.items():
                if not self.options.test:
                    tags.append(self._tag_package(pkg))
                eargs['fjord_svc_version'][pkg] = version

            subprocess.check_call(ansiargs + ('--tags', ','.join(tags), '-e', json.dumps(eargs)),
                                  env = env)

    def _deploy(self, tasks):
        """
        Confirm then run ``tasks``.

        :raises FdSVCDeployExitNoClear: when there is nothing to deploy.
        """
        if not tasks:
            LOG.error("nothing to deploy")
            raise FdSVCDeployExitNoClear(2)

        self.summary(tasks)
        self.run_tasks(tasks)

    def do_host(self):
        """``-m host``: deploy on hand-picked hosts."""
        self._deploy(self.pkgs_version(self.pkgs(self.hosts())))

    def do_group(self):
        """``-m group``: deploy on hand-picked host groups."""
        self._deploy(self.pkgs_version(self.pkgs(self.groups())))

    def do_refresh_cache(self):
        """``-m refresh-cache``: rebuild the redis groups cache."""
        LOG.info("refreshing groups cache")
        self._groups_list(from_cache = False)


def main(options):
    """
    Main function.

    Configure logging, run the ``do_<mode>`` method selected by ``-m`` and
    turn every way the run can end into a process exit code.

    :param options: namespace returned by argv_parse_check().
    :returns: the exit code (None when the mode method succeeded).
    """

    xformat = "%(levelname)s:%(asctime)-15s: %(message)s"
    datefmt = '%Y-%m-%d %H:%M:%S'
    logging.basicConfig(level   = options.loglevel,
                        format  = xformat,
                        datefmt = datefmt)

    rc            = 1
    clear         = True
    log_error     = False
    exc_info      = None
    err           = None
    fd_svc_deploy = None

    try:
        fd_svc_deploy = FdSVCDeploy(options)
        rc            = getattr(fd_svc_deploy, "do_%s" % options.mode.replace('-', '_'))()
    except FdSVCDeployExitNoClear as e:
        clear = False
        err = e
        rc = e.code
    except FdSVCDeployExit as e:
        rc = e.code
        err = e
    except subprocess.CalledProcessError as e:
        # Keep the failed command's output on screen.
        clear = False
        err = e
        log_error = LOG.error
        rc = e.returncode
    except (SystemExit, KeyboardInterrupt) as e:
        err = e
        rc = 255
    except Exception as e:
        err = e
        log_error = LOG.error
        # The message is logged after the screen is cleared, outside this
        # except block, where LOG.exception() would no longer see the active
        # exception: hand it over explicitly to keep the traceback.
        exc_info = e
        rc = 5

    if clear and rc is not None and rc > 0:
        if fd_svc_deploy:
            fd_svc_deploy.dialog.clear()
        os.system('clear')

    if log_error and err:
        log_error(err, exc_info = exc_info)

    return rc


if __name__ == '__main__':
    # Parse the command line, run the selected mode and exit with its code.
    _options = argv_parse_check()
    sys.exit(main(_options))