#!/usr/bin/env python3

import argparse
import dataclasses
import itertools
# Using multiprocessing not for the usual GIL avoidance reasons,
# but because of historical risks of mixing subprocess+threading.
import multiprocessing
import os
import subprocess
import sys
import tempfile
import time

from pathlib import Path
from typing import Optional

@dataclasses.dataclass
class MountState:
    """State for one stress run, shared with the journal-follower process.

    Building an instance creates a brand-new, uniquely named directory under
    *mount_parent* to act as the mount point, derives the systemd unit name
    from that directory's name, and allocates the events that are handed to
    the follower process.
    """
    mount_path: Path  # mount point passed to arv-mount and fusermount
    unit_name: str  # name of the systemd --user unit running arv-mount
    ready_flag: multiprocessing.Event  # NOTE(review): never set or read in this script
    crash_flag: multiprocessing.Event  # set by follow_journal on a FUSE crash report
    returncode: Optional[int]  # NOTE(review): starts as None; never updated in this script

    def __init__(self, mount_parent):
        # mkdtemp guarantees a fresh directory that no other run is using.
        workdir = Path(tempfile.mkdtemp(dir=mount_parent, prefix='arv-mount-stress-'))
        unit = f'{workdir.stem}.service'
        self.mount_path, self.unit_name = workdir, unit
        # Created in order: ready_flag first, then crash_flag.
        self.ready_flag, self.crash_flag = (multiprocessing.Event() for _ in range(2))
        self.returncode = None


def follow_journal(mount_state, journal_fd):
    """Watch journal text for FUSE crash reports until end of file.

    Reads file descriptor *journal_fd* line by line and sets
    ``mount_state.crash_flag`` for every line ending with the arv-mount
    "Unhandled exception during FUSE operation" error. The descriptor is
    closed when the stream ends.
    """
    crash_suffix = ' ERROR: Unhandled exception during FUSE operation'
    # Undecodable bytes are replaced rather than raising: a UnicodeDecodeError
    # here would end this process and silently disable crash detection.
    with open(journal_fd, encoding='utf-8', errors='replace') as journal_out:
        for line in journal_out:
            # Strip the newline so a final, unterminated line still matches.
            if line.rstrip('\n').endswith(crash_suffix):
                mount_state.crash_flag.set()
    
def schedule_tries(mount_state, start_sleep=90, sleep_mult=2, stop_sleep=900):
    """Yield try numbers 1, 2, ... separated by exponentially growing pauses.

    The first try is yielded immediately (if *start_sleep* < *stop_sleep*).
    After each try the generator stops if ``mount_state.crash_flag`` is set.
    Otherwise it waits until the current interval has passed since the try
    was yielded (time the caller spends on the try counts toward it), then
    multiplies the interval by *sleep_mult*. Tries continue while the
    interval stays below *stop_sleep*. No sleep follows the final try, so a
    caller that exhausts the generator is not held up by a useless wait.
    With *sleep_mult* <= 1 the interval never reaches *stop_sleep*, so only
    the crash flag ends the schedule.
    """
    tries = itertools.count(1)
    sleep_time = start_sleep
    while sleep_time < stop_sleep:
        # monotonic() is immune to wall-clock adjustments while we wait.
        start_time = time.monotonic()
        yield next(tries)
        next_sleep = sleep_time * sleep_mult
        # Stop before sleeping if no further try would be scheduled anyway.
        if mount_state.crash_flag.is_set() or next_sleep >= stop_sleep:
            break
        elapsed_time = time.monotonic() - start_time
        time.sleep(max(0, sleep_time - elapsed_time))
        sleep_time = next_sleep

def stress_mount(mount_state):
    """Run ``ls -lR`` concurrently over every top-level entry of the mount.

    Starts one ``ls -lR`` per entry directly under ``mount_state.mount_path``,
    waits for all of them, and returns the "worst" exit status: the one with
    the largest magnitude. A process killed by a signal has a negative return
    code, which plain ``max()`` would hide behind a sibling's 0.

    Raises ValueError if the mount directory has no entries.
    """
    procs = [
        subprocess.Popen(
            ['ls', '-lR', str(path)],
            stdin=subprocess.DEVNULL,
            stdout=subprocess.DEVNULL,
        ) for path in mount_state.mount_path.iterdir()
    ]
    print("Running", len(procs), "ls processes")
    returncodes = [proc.wait() for proc in procs]
    result = max(returncodes, key=abs)
    print("Stress returncode =", result)
    return result

def clean_mount(mount_state, umount_timeout=10, kill_timeout=20):
    """Unmount the FUSE mount, killing the arv-mount unit if unmount stalls.

    Runs ``fusermount -qu`` on ``mount_state.mount_path``. If it has not
    exited within *umount_timeout* seconds, asks systemd to kill the
    ``mount_state.unit_name`` unit and waits up to *kill_timeout* more
    seconds. If fusermount is still running after that, it is sent SIGKILL.
    Otherwise ``Popen.__exit__`` would wait on it with no timeout at all.

    Returns fusermount's exit status (negative if killed by a signal).
    """
    with subprocess.Popen(
        ['fusermount', '-qu', str(mount_state.mount_path)],
        stdin=subprocess.DEVNULL,
    ) as umount_proc:
        try:
            umount_proc.wait(umount_timeout)
        except subprocess.TimeoutExpired:
            subprocess.run(
                ['systemctl', '--user', 'kill', mount_state.unit_name],
                stdin=subprocess.DEVNULL,
            )
            try:
                umount_proc.wait(kill_timeout)
            except subprocess.TimeoutExpired:
                umount_proc.kill()
                umount_proc.wait()
    return umount_proc.returncode

def main(arglist=None):
    """Mount Arvados with arv-mount under systemd and stress it with ls.

    Steps:
    - Start ``arv-mount`` as a transient ``systemd --user`` unit on a fresh
      directory.
    - Follow that unit's journal in a child process to catch unhandled FUSE
      exceptions.
    - Wait for the mount to show contents.
    - Run rounds of ``ls -lR`` with growing pauses between them.

    The mount is always unmounted and its directory removed afterwards.

    *arglist* is accepted for the usual entry-point signature but is not
    currently used.

    Returns the process exit status:
    - the first failing stress status, if any;
    - otherwise ``os.EX_SOFTWARE`` if the journal reported an unhandled FUSE
      exception;
    - otherwise fusermount's exit status.

    Raises RuntimeError if the mount never shows any contents.
    """
    mount_parent = Path(
        os.environ.get('XDG_RUNTIME_DIR')
        or os.environ.get('TMPDIR')
        or '/tmp'
    )
    mount_state = MountState(mount_parent)
    unit_arg = f'--unit={mount_state.unit_name}'
    subprocess.run([
        'systemd-run', '--user', unit_arg,
        'arv-mount', '--foreground', '--shared',
        f'--directory-cache={2 << 20}',
        '--', str(mount_state.mount_path),
    ], stdin=subprocess.DEVNULL, check=True)
    journal_proc = subprocess.Popen(
        ['journalctl', '--user', unit_arg, '--follow', '--output=cat'],
        stdin=subprocess.DEVNULL,
        stdout=subprocess.PIPE,
    )
    # The follower reads the journal pipe through an inherited file
    # descriptor, which only the 'fork' start method provides. Other start
    # methods (forkserver is the Linux default from Python 3.14) would give
    # the child a descriptor number that is not open there.
    follow_proc = multiprocessing.get_context('fork').Process(
        target=follow_journal,
        args=(mount_state, journal_proc.stdout.fileno()),
    )
    follow_proc.start()
    have_contents = False
    stress_returncode = os.EX_OK
    try:
        print("waiting for mount")
        for _ in schedule_tries(mount_state, 1, 2, 60):
            if have_contents := any(mount_state.mount_path.iterdir()):
                break
        # An explicit raise, unlike assert, still runs under `python -O`.
        if not have_contents:
            raise RuntimeError("mount never had contents")
        for _ in schedule_tries(mount_state):
            stress_returncode = stress_mount(mount_state)
            if stress_returncode != os.EX_OK:
                break
    finally:
        # Ending journalctl closes the pipe's write end, so the follower
        # reaches end of file and exits.
        journal_proc.terminate()
        umount_returncode = clean_mount(mount_state)
        follow_proc.join()
        journal_proc.wait()  # reap journalctl so it does not linger as a zombie
        journal_proc.stdout.close()
        mount_state.mount_path.rmdir()
    # The follower has drained the journal by now, so this flag is final.
    crashed = mount_state.crash_flag.is_set()
    if stress_returncode != os.EX_OK or crashed:
        if sys.stdout.isatty():
            subprocess.run(['journalctl', '--user', '--no-pager', unit_arg])
        # A logged FUSE crash is a failure even if every ls exited cleanly.
        return stress_returncode if stress_returncode != os.EX_OK else os.EX_SOFTWARE
    return umount_returncode

if __name__ == '__main__':
    # Pass the command-line arguments (without the program name) to main()
    # and use its return value as the process exit status.
    raise SystemExit(main(sys.argv[1:]))
