#!/usr/bin/python3
# Copyright (C) 2020 Jelmer Vernooij <jelmer@jelmer.uk>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

"""Artifacts."""

from typing import Optional, List
import asyncio

from io import BytesIO
import logging
import os
import shutil
import tempfile

from aiohttp import ClientSession, ClientResponseError
from yarl import URL


# Timeout used for Google Cloud Storage requests when a caller passes none;
# it is handed to gcloud.aio.storage as ``timeout=`` (presumably seconds —
# confirm against the gcloud-aio-storage docs).
DEFAULT_GCS_TIMEOUT = 60


class ServiceUnavailable(Exception):
    """Raised when the remote storage service is temporarily unavailable.

    GCSArtifactManager raises this when the service answers with HTTP 503,
    so callers may choose to retry later or fall back to another location.
    """


class ArtifactsMissing(Exception):
    """Raised when no artifacts are stored for the requested run.

    The run id is passed as the single exception argument.
    """


class ArtifactManager(object):
    """Manage sets of per-run artifacts.

    Artifacts are named files; no other metadata is stored.

    This base class only defines the interface. Every operation except the
    async context manager protocol raises NotImplementedError and must be
    overridden by a concrete subclass.
    """

    async def store_artifacts(
        self,
        run_id: str,
        local_path: str,
        names: Optional[List[str]] = None,
        timeout: Optional[float] = None,
    ):
        """Store a set of artifacts.

        Args:
          run_id: The run id
          local_path: Local path to retrieve files from
          names: Optional list of filenames in local_path to upload.
            Defaults to all files in local_path.
          timeout: Optional timeout; its interpretation (and the default
            used when None) is up to the implementation.
        """
        raise NotImplementedError(self.store_artifacts)

    async def delete_artifacts(self, run_id: str):
        """Delete all artifacts stored for a run.

        Args:
          run_id: The run id
        """
        raise NotImplementedError(self.delete_artifacts)

    async def get_artifact(self, run_id, filename, timeout=None):
        """Open a single artifact for reading.

        Args:
          run_id: The run id
          filename: Name of the artifact within the run
          timeout: Optional timeout (implementation-defined)
        Returns:
          A binary file-like object with the artifact contents
        Raises:
          FileNotFoundError: if the artifact does not exist
        """
        raise NotImplementedError(self.get_artifact)

    async def retrieve_artifacts(
        self, run_id, local_path, filter_fn=None, timeout=None
    ):
        """Copy the artifacts of a run into a local directory.

        Args:
          run_id: The run id
          local_path: Existing local directory to write the files into
          filter_fn: Optional predicate on an artifact's filename; only
            artifacts for which it returns true are retrieved
          timeout: Optional timeout (implementation-defined)
        Raises:
          ArtifactsMissing: if no artifacts exist for run_id
        """
        raise NotImplementedError(self.retrieve_artifacts)

    async def iter_ids(self):
        """Asynchronously iterate over the ids of runs with stored artifacts."""
        raise NotImplementedError(self.iter_ids)
        # Unreachable, but the ``yield`` makes this an async generator like
        # the subclass implementations. ``async for`` on an unimplemented
        # manager therefore raises NotImplementedError rather than a
        # TypeError about iterating a coroutine.
        yield  # pragma: no cover

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        return False


class LocalArtifactManager(ArtifactManager):
    """Artifact manager that stores artifacts in a local directory.

    Artifact ``filename`` of run ``run_id`` lives at
    ``<path>/<run_id>/<filename>``.
    """

    def __init__(self, path):
        self.path = os.path.abspath(path)
        # exist_ok avoids a check-then-create race with concurrent creators;
        # it still raises if self.path exists but is not a directory.
        os.makedirs(self.path, exist_ok=True)

    def __repr__(self):
        return "%s(%r)" % (type(self).__name__, self.path)

    async def store_artifacts(self, run_id, local_path, names=None, timeout=None):
        """Copy files from local_path into the directory for run_id.

        Args:
          run_id: The run id
          local_path: Directory to copy files from
          names: Filenames within local_path to copy; defaults to every
            entry in local_path.
          timeout: Accepted for interface compatibility; ignored.
        """
        run_dir = os.path.join(self.path, run_id)
        try:
            os.mkdir(run_dir)
        except FileExistsError:
            pass
        if names is None:
            names = os.listdir(local_path)
        for name in names:
            shutil.copy(os.path.join(local_path, name), os.path.join(run_dir, name))

    async def iter_ids(self):
        """Yield the name of every entry in the storage directory."""
        # Snapshot the listing and close the scandir handle before yielding.
        # Callers (e.g. upload_backup_artifacts) delete runs while iterating,
        # and an abandoned async generator would otherwise keep the handle
        # open indefinitely.
        with os.scandir(self.path) as entries:
            run_ids = [entry.name for entry in entries]
        for run_id in run_ids:
            yield run_id

    async def delete_artifacts(self, run_id):
        """Remove the directory for run_id and everything in it."""
        shutil.rmtree(os.path.join(self.path, run_id))

    async def get_artifact(self, run_id, filename, timeout=None):
        """Return an open binary file for an artifact; the caller must close it.

        Raises FileNotFoundError if the artifact does not exist. timeout is
        accepted for interface compatibility and ignored.
        """
        return open(os.path.join(self.path, run_id, filename), "rb")

    async def retrieve_artifacts(
        self, run_id, local_path, filter_fn=None, timeout=None
    ):
        """Copy the artifacts of run_id into local_path.

        Args:
          run_id: The run id
          local_path: Existing directory to copy the files into
          filter_fn: Optional predicate on an artifact's filename; only
            artifacts for which it returns true are copied
          timeout: Accepted for interface compatibility; ignored.
        Raises:
          ArtifactsMissing: if there is no directory for run_id
        """
        run_path = os.path.join(self.path, run_id)
        if not os.path.isdir(run_path):
            raise ArtifactsMissing(run_id)
        # Use scandir as a context manager so the directory handle is closed
        # even if a copy fails part-way through.
        with os.scandir(run_path) as entries:
            for entry in entries:
                if filter_fn is not None and not filter_fn(entry.name):
                    continue
                shutil.copy(entry.path, os.path.join(local_path, entry.name))


class GCSArtifactManager(ArtifactManager):
    """Artifact manager backed by a Google Cloud Storage bucket.

    The bucket name is the host part of ``location`` (``gs://BUCKET/``).
    Artifact ``filename`` of run ``run_id`` is stored as the object
    ``run_id/filename``. Use this as an async context manager: the HTTP
    session, storage client and bucket handle are only created in
    ``__aenter__``.
    """

    def __init__(self, location, creds_path=None, trace_configs=None):
        self.bucket_name = URL(location).host
        # Optional service file passed to gcloud.aio.storage.Storage.
        self.creds_path = creds_path
        # Optional aiohttp trace configs for the ClientSession.
        self.trace_configs = trace_configs

    def __repr__(self):
        return "%s(%r)" % (type(self).__name__, "gs://%s/" % self.bucket_name)

    async def __aenter__(self):
        # Imported lazily, presumably so gcloud-aio-storage is only required
        # when a GCS location is actually used.
        from gcloud.aio.storage import Storage

        self.session = ClientSession(trace_configs=self.trace_configs)
        await self.session.__aenter__()
        try:
            self.storage = Storage(service_file=self.creds_path, session=self.session)
            self.bucket = self.storage.get_bucket(self.bucket_name)
        except BaseException:
            # __aexit__ is not called when __aenter__ fails, so close the
            # session here rather than leaking it.
            await self.session.close()
            raise
        # Return self like ArtifactManager.__aenter__ so that
        # ``async with manager as m`` binds the manager rather than None.
        return self

    async def __aexit__(self, exc_type, exc, tb):
        await self.session.__aexit__(exc_type, exc, tb)
        return False

    async def store_artifacts(self, run_id, local_path, names=None, timeout=None):
        """Upload files from local_path as the artifacts of run_id.

        Args:
          run_id: The run id
          local_path: Directory containing the files
          names: Filenames within local_path to upload; defaults to every
            entry in local_path. Nothing is uploaded if this is empty.
          timeout: Per-upload timeout; DEFAULT_GCS_TIMEOUT when None.
        Raises:
          ServiceUnavailable: if GCS responds with HTTP 503
        """
        if timeout is None:
            timeout = DEFAULT_GCS_TIMEOUT
        if names is None:
            names = os.listdir(local_path)
        if not names:
            return
        uploads = [
            self.storage.upload_from_filename(
                self.bucket_name,
                "%s/%s" % (run_id, name),
                os.path.join(local_path, name),
                timeout=timeout,
            )
            for name in names
        ]
        try:
            await asyncio.gather(*uploads)
        except ClientResponseError as e:
            if e.status == 503:
                raise ServiceUnavailable() from e
            raise
        logging.info(
            "Uploaded %r to run %s in bucket %s.", names, run_id, self.bucket_name
        )

    async def iter_ids(self):
        """Yield each distinct run id (first path component) in the bucket."""
        seen = set()
        for name in await self.bucket.list_blobs():
            run_id = name.split("/")[0]
            if run_id not in seen:
                seen.add(run_id)
                yield run_id

    async def retrieve_artifacts(
        self, run_id, local_path, filter_fn=None, timeout=None
    ):
        """Download the artifacts of run_id into local_path.

        Objects are written under their basename, so any deeper structure
        below ``run_id/`` is flattened.

        Args:
          run_id: The run id
          local_path: Existing directory to write the files into
          filter_fn: Optional predicate on an artifact's basename; only
            artifacts for which it returns true are downloaded
          timeout: Per-download timeout; DEFAULT_GCS_TIMEOUT when None.
        Raises:
          ArtifactsMissing: if no objects exist under ``run_id/``
          ServiceUnavailable: if GCS responds with HTTP 503
        """
        if timeout is None:
            timeout = DEFAULT_GCS_TIMEOUT
        try:
            names = await self.bucket.list_blobs(prefix=run_id + "/")
        except ClientResponseError as e:
            if e.status == 503:
                raise ServiceUnavailable() from e
            raise
        if not names:
            raise ArtifactsMissing(run_id)

        async def download_blob(name):
            # Download before opening the destination so a failed download
            # does not leave an empty or truncated file behind, and so files
            # are not held open while waiting on the network.
            data = await self.storage.download(
                bucket=self.bucket_name, object_name=name, timeout=timeout
            )
            with open(os.path.join(local_path, os.path.basename(name)), "wb") as f:
                f.write(data)

        try:
            await asyncio.gather(
                *[
                    download_blob(name)
                    for name in names
                    if filter_fn is None or filter_fn(os.path.basename(name))
                ]
            )
        except ClientResponseError as e:
            if e.status == 503:
                raise ServiceUnavailable() from e
            raise

    async def get_artifact(self, run_id, filename, timeout=None):
        """Download a single artifact and return it as an in-memory file.

        Args:
          run_id: The run id
          filename: Name of the artifact within the run
          timeout: Download timeout; DEFAULT_GCS_TIMEOUT when None, matching
            the other methods of this class.
        Returns:
          A BytesIO with the artifact contents
        Raises:
          ServiceUnavailable: if GCS responds with HTTP 503
          FileNotFoundError: if the artifact does not exist (HTTP 404)
        """
        if timeout is None:
            timeout = DEFAULT_GCS_TIMEOUT
        object_name = "%s/%s" % (run_id, filename)
        try:
            data = await self.storage.download(
                bucket=self.bucket_name,
                object_name=object_name,
                timeout=timeout,
            )
        except ClientResponseError as e:
            if e.status == 503:
                raise ServiceUnavailable() from e
            if e.status == 404:
                raise FileNotFoundError(object_name) from e
            raise
        return BytesIO(data)


def get_artifact_manager(location, trace_configs=None):
    """Create the artifact manager appropriate for ``location``.

    ``gs://`` URLs select Google Cloud Storage (``trace_configs`` is passed
    on to it); anything else is treated as a local directory path.
    """
    if not location.startswith("gs://"):
        return LocalArtifactManager(location)
    return GCSArtifactManager(location, trace_configs=trace_configs)


async def list_ids(manager):
    """Print every run id known to ``manager``, one per line.

    The manager is entered as an async context manager for the duration.
    """
    async with manager:
        run_ids = manager.iter_ids()
        async for run_id in run_ids:
            print(run_id)


async def upload_backup_artifacts(
    backup_artifact_manager, artifact_manager, timeout=None
):
    """Move artifacts from a backup location to the primary artifact manager.

    Each run found in backup_artifact_manager is retrieved into a temporary
    directory and stored in artifact_manager. A run is deleted from the
    backup only after its upload succeeds. Runs that fail to retrieve or
    upload are logged and left in the backup so a later call can retry them;
    one bad run does not stop the others from being processed.

    Args:
      backup_artifact_manager: Manager to move artifacts out of; it must
        support iter_ids, retrieve_artifacts and delete_artifacts.
      artifact_manager: Manager to store the artifacts in.
      timeout: Passed through to retrieve_artifacts and store_artifacts.
    """
    # Snapshot the ids first: runs are deleted from the backup below, and
    # deleting entries while the backend is still enumerating them has
    # unspecified results (e.g. os.scandir on a local directory).
    run_ids = [run_id async for run_id in backup_artifact_manager.iter_ids()]
    for run_id in run_ids:
        with tempfile.TemporaryDirectory() as td:
            try:
                await backup_artifact_manager.retrieve_artifacts(
                    run_id, td, timeout=timeout
                )
            except Exception as e:
                logging.warning(
                    "Unable to retrieve backup artifacts (%r): %s", run_id, e
                )
                continue
            try:
                await artifact_manager.store_artifacts(run_id, td, timeout=timeout)
            except Exception as e:
                logging.warning(
                    "Unable to upload backup artifacts (%r): %s", run_id, e
                )
            else:
                await backup_artifact_manager.delete_artifacts(run_id)


async def store_artifacts_with_backup(manager, backup_manager, from_dir, run_id, names):
    """Store artifacts in ``manager``, falling back to ``backup_manager``.

    Any exception from the primary upload is logged. If a backup manager is
    configured, the artifacts are stored there instead. Otherwise the
    original exception is re-raised. An exception from the backup upload
    itself propagates to the caller.
    """
    try:
        await manager.store_artifacts(run_id, from_dir, names)
    except Exception as err:
        logging.warning("Unable to upload artifacts for %r: %r", run_id, err)
        if not backup_manager:
            logging.warning("No backup artifact manager set. ")
            raise
        await backup_manager.store_artifacts(run_id, from_dir, names)
        logging.info(
            "Uploading results to backup artifact location %r.", backup_manager
        )


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Inspect artifact storage.")
    # required=True makes argparse report a missing subcommand. Previously,
    # running with no arguments silently did nothing and exited successfully.
    subparsers = parser.add_subparsers(dest="command", required=True)
    list_parser = subparsers.add_parser(
        "list", help="Print the ids of all runs with stored artifacts."
    )
    list_parser.add_argument(
        "location",
        type=str,
        help="Artifact location: a gs://BUCKET/ URL or a local directory.",
    )
    args = parser.parse_args()
    if args.command == "list":
        manager = get_artifact_manager(args.location)
        asyncio.run(list_ids(manager))
