#!/usr/bin/env python3
# -----------------------------------------------------------------------------
# voxl-wg-config-api — drone-side HTTP service for safe remote config edits.
#
# This service exists so the dashboard can read and write a small,
# whitelisted set of /etc/modalai/* config files (and restart their
# matching systemd units) without needing SSH access to the drone.
#
# Why a REST service instead of SSH:
#   - No password handling on the dashboard side.
#   - No SSH handshake overhead (one HTTP round-trip per op instead of
#     ~1-2s of SSH negotiation).
#   - Narrower attack surface: this service can only touch files in the
#     whitelist and only restart units in the whitelist. Shell access
#     gives you the whole drone.
#   - Easier to reason about: explicit JSON in/out, explicit status codes,
#     no shell quoting concerns.
#
# Security model:
#   - Bound to the wg0 IP only. Not exposed on the drone's LTE / WiFi /
#     ethernet interfaces. If you're on the WireGuard VPN, you can talk
#     to this; otherwise the port doesn't exist for you.
#   - No application-level auth in v1 — VPN reachability IS the auth.
#     Same security stance as the previous SSH-based approach: anyone on
#     the VPN had shell-as-root via the documented default password.
#     If we later need per-operator audit logs / revocation, swap in a
#     bearer token loaded from /etc/voxl-wireguard/config-api.token.
#
# Endpoints:
#   GET  /healthz                  → 200 "ok" (text/plain)
#   GET  /configs                  → {"configs": [...]} — registered configs
#                                    whose file exists on this drone
#   GET  /config/<name>            → {"content", "mtime", "size", "parse_error", ...}
#   PUT  /config/<name>            → validate, back up, then write atomically
#                                    body: {"content": "...", "mtime_seen": N}
#                                    (mtime_seen 0 / null = force overwrite)
#                                    409 if mtime_seen mismatches current
#   POST /restart-service/<name>   → systemctl restart <name>; <name> must be
#                                    a service listed in the config registry
#
# Compatibility: pure stdlib Python 3.6+ so it runs on the qrb5165
# (Ubuntu 18.04) drones as well as newer ones. No pip deps, no venv.
# -----------------------------------------------------------------------------

import json
import os
import re
import shutil
import socket
import subprocess
import sys
import tempfile
import threading
import time
from http.server import HTTPServer, BaseHTTPRequestHandler
from socketserver import ThreadingMixIn

# ---- Tunables --------------------------------------------------------------

# TCP port the API listens on (bound to the wg0 address only — see main()).
# The dashboard's _DRONE_API_PORT must match, so change both together.
# Picked because nothing on the VOXL images we've looked at uses it.
# NOTE(review): confirm 7900 against the IANA port registry before relying
# on it being unassigned.
LISTEN_PORT = 7900

# Startup polling for wg0: retry every WG0_WAIT_INTERVAL_S seconds and give
# up after WG0_WAIT_TIMEOUT_S. Exiting at that point lets systemd restart
# the service instead of us spinning hot.
WG0_WAIT_INTERVAL_S = 2.0
WG0_WAIT_TIMEOUT_S = 60.0

# Upper bound (bytes) on a PUT request body. Real config files are far
# smaller — voxl-vtx.conf is ~5KB.
MAX_BODY_BYTES = 1000 * 1000

# Registry of files this service may read/write. To add a config: add an
# entry keyed by its URL name, check that the file exists at `path` on the
# target drone images, and confirm the systemd unit name in `service`.
# GET /configs hides entries whose file is missing on the running drone.
DRONE_CONFIGS = {
    "voxl-vtx": dict(
        path="/etc/modalai/voxl-vtx.conf",
        service="voxl-vtx",
        description="VTX (wireless video transmitter)",
        format="json-with-c-comments",
    ),
    # Template for further entries:
    # "voxl-camera-server": dict(
    #     path="/etc/modalai/voxl-camera-server.conf",
    #     service="voxl-camera-server",
    #     description="Camera server",
    #     format="json-with-c-comments",
    # ),
}

# Character set allowed in a <name> URL path segment. Checked before any
# registry lookup so malformed paths are rejected up front.
_CONFIG_NAME_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
# Service names use the same character set — this is the same compiled
# pattern object, not a copy.
_SERVICE_NAME_RE = _CONFIG_NAME_RE


def log(msg):
    """Write `msg` to stderr as one "[voxl-wg-config-api] "-prefixed line.

    The stream is flushed after every line: before Python 3.9 a
    non-interactive stderr is block-buffered, and we want each message to
    reach the log immediately.
    """
    stream = sys.stderr
    stream.write("".join(("[voxl-wg-config-api] ", msg, "\n")))
    stream.flush()


# ---- wg0 IP discovery ------------------------------------------------------
#
# Bind only to the WireGuard interface IP so the service is unreachable
# from LTE / WiFi / Ethernet. Reading via `ip addr show wg0` is the
# simplest portable approach — pyroute2 would be cleaner but is not on
# stock VOXL images.

def get_wg0_ip():
    """Return the IPv4 address assigned to wg0 as a string, or None.

    Runs `ip -4 addr show dev wg0` (3 s timeout) and returns the first
    `inet A.B.C.D/N` address from its output. Returns None when `ip` is not
    installed, exits non-zero, or times out, or when the output contains no
    IPv4 `inet` line.
    """
    cmd = ["ip", "-4", "addr", "show", "dev", "wg0"]
    try:
        raw = subprocess.check_output(cmd, stderr=subprocess.DEVNULL, timeout=3)
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError):
        return None
    text = raw.decode("utf-8", errors="replace")
    # Address lines look like "    inet 10.8.0.X/Y ...".
    found = re.search(r"^\s*inet\s+(\d+\.\d+\.\d+\.\d+)/\d+", text, re.MULTILINE)
    if found is None:
        return None
    return found.group(1)


def wait_for_wg0(timeout_s=WG0_WAIT_TIMEOUT_S):
    """Poll get_wg0_ip() until wg0 reports an IPv4 address.

    Polls every WG0_WAIT_INTERVAL_S seconds until `timeout_s` seconds have
    passed on the monotonic clock. Returns the address string as soon as
    one is seen, or None once the deadline has passed.
    """
    give_up_at = time.monotonic() + timeout_s
    while time.monotonic() < give_up_at:
        addr = get_wg0_ip()
        if not addr:
            time.sleep(WG0_WAIT_INTERVAL_S)
            continue
        return addr
    return None


# ---- Path helpers ----------------------------------------------------------

def _config_entry(name):
    """Return the DRONE_CONFIGS entry for `name`, or None if there isn't one.

    `name` comes straight from the URL path, so it is first validated
    against _CONFIG_NAME_RE. `fullmatch` is used instead of `match` because
    the pattern's `$` anchor also matches just before a trailing newline,
    so `match` would accept e.g. "voxl-vtx\\n".
    """
    if not _CONFIG_NAME_RE.fullmatch(name):
        return None
    return DRONE_CONFIGS.get(name)


def _strip_c_comments_for_json(text):
    """Return `text` with C-style comments removed, for JSON validation only.

    Drops `// ...` line comments and `/* ... */` block comments that appear
    outside double-quoted strings. Newlines inside removed comments are
    kept, so line numbers in json error messages still match the original
    text. Inside strings, a backslash and the character after it are copied
    through as a pair, so an escaped quote does not end the string. The
    result is only fed to json.loads; the file written to disk keeps its
    comments verbatim.
    """
    normal, in_str, in_line, in_block = range(4)
    mode = normal
    kept = []
    pos = 0
    end = len(text)
    while pos < end:
        ch = text[pos]
        pair = text[pos:pos + 2]
        if mode == in_line:
            # Skip to end of line; the newline itself survives.
            if ch == "\n":
                kept.append(ch)
                mode = normal
            pos += 1
        elif mode == in_block:
            if pair == "*/":
                mode = normal
                pos += 2
            else:
                if ch == "\n":
                    kept.append(ch)
                pos += 1
        elif mode == in_str:
            if ch == "\\" and pos + 1 < end:
                kept.append(pair)
                pos += 2
            else:
                kept.append(ch)
                if ch == '"':
                    mode = normal
                pos += 1
        elif ch == '"':
            kept.append(ch)
            mode = in_str
            pos += 1
        elif pair == "//":
            mode = in_line
            pos += 2
        elif pair == "/*":
            mode = in_block
            pos += 2
        else:
            kept.append(ch)
            pos += 1
    return "".join(kept)


# ---- HTTP handler ---------------------------------------------------------

class ConfigAPIHandler(BaseHTTPRequestHandler):
    """Request handler for the config API (endpoint list at the top of
    this file).

    BaseHTTPRequestHandler creates one instance per connection. Under the
    threaded server several instances run at once, so anything shared
    between requests (the write lock) lives on the class.
    """
    server_version = "voxl-wg-config-api/1"
    # Tighter than default — we don't need the Python version exposed.
    sys_version = ""
    # Per-connection socket timeout in seconds, applied by
    # StreamRequestHandler.setup(). Without it, a client that stalls, or
    # sends fewer body bytes than its Content-Length, pins a handler
    # thread indefinitely.
    timeout = 30
    # Held across the mtime check, backup and replace of a PUT, so two
    # concurrent writers can't both pass the conflict check or clobber
    # each other's .bak.dashboard copy.
    _write_lock = threading.Lock()

    def log_message(self, fmt, *args):
        # Route access logs through our prefix so journalctl groups them
        # with our other output, replacing the default
        # "127.0.0.1 - - [date]" prefix with a tighter format. Flush like
        # log() does: before Python 3.9 a non-interactive stderr is
        # block-buffered, so access lines could otherwise lag or be lost.
        sys.stderr.write("[voxl-wg-config-api] %s %s\n" %
                         (self.address_string(), fmt % args))
        sys.stderr.flush()

    # ---- responses ----
    def _send_json(self, code, payload):
        """Send `payload` as an uncacheable JSON response with status `code`.

        A client disconnect while writing is ignored — there is nobody left
        to report it to.
        """
        body = json.dumps(payload).encode("utf-8")
        try:
            self.send_response(code)
            self.send_header("Content-Type", "application/json")
            self.send_header("Content-Length", str(len(body)))
            self.send_header("Cache-Control", "no-store")
            self.end_headers()
            self.wfile.write(body)
        except (BrokenPipeError, ConnectionResetError, ConnectionAbortedError):
            pass

    def _send_text(self, code, text, ctype="text/plain"):
        """Send `text` (str, encoded as UTF-8, or raw bytes) with status `code`."""
        body = text.encode("utf-8") if isinstance(text, str) else text
        try:
            self.send_response(code)
            self.send_header("Content-Type", ctype)
            self.send_header("Content-Length", str(len(body)))
            self.end_headers()
            self.wfile.write(body)
        except (BrokenPipeError, ConnectionResetError, ConnectionAbortedError):
            pass

    # ---- routing ----
    def do_GET(self):
        path = self.path.split("?", 1)[0]
        if path == "/healthz":
            return self._send_text(200, "ok\n")
        if path == "/configs":
            return self._handle_list_configs()
        if path.startswith("/config/"):
            name = path[len("/config/"):]
            return self._handle_read_config(name)
        return self._send_json(404, {"error": "not found"})

    def do_PUT(self):
        path = self.path.split("?", 1)[0]
        if path.startswith("/config/"):
            name = path[len("/config/"):]
            return self._handle_write_config(name)
        return self._send_json(405, {"error": "PUT not allowed here"})

    def do_POST(self):
        path = self.path.split("?", 1)[0]
        if path.startswith("/restart-service/"):
            name = path[len("/restart-service/"):]
            return self._handle_restart_service(name)
        return self._send_json(405, {"error": "POST not allowed here"})

    # ---- /configs ----
    def _handle_list_configs(self):
        """Return the registry filtered to entries whose file actually
        exists on this drone. Lets a single registry serve drones with
        different installed packages (no voxl-vtx → no voxl-vtx entry)."""
        out = []
        for name, info in DRONE_CONFIGS.items():
            if not os.path.isfile(info["path"]):
                continue
            out.append({
                "name":        name,
                "path":        info["path"],
                "service":     info.get("service"),
                "description": info["description"],
                "format":      info.get("format"),
            })
        return self._send_json(200, {"configs": out})

    # ---- /config/<name> GET ----
    def _handle_read_config(self, name):
        """Return a config's content and metadata.

        `parse_error` is the JSON parse error of the on-disk content (after
        comment stripping), or None if it parses, so the dashboard can warn
        about a file that is already malformed.
        """
        info = _config_entry(name)
        if not info:
            return self._send_json(404, {"error": "unknown config '%s'" % name})
        path = info["path"]
        try:
            st = os.stat(path)
        except FileNotFoundError:
            return self._send_json(404, {"error": "file not present: %s" % path})
        except PermissionError as e:
            return self._send_json(500, {"error": "permission denied reading %s: %s" % (path, e)})
        try:
            with open(path, "r", encoding="utf-8", errors="replace") as f:
                content = f.read()
        except OSError as e:
            return self._send_json(500, {"error": "read failed: %s" % e})
        parse_error = None
        try:
            json.loads(_strip_c_comments_for_json(content))
        except (ValueError, RecursionError) as e:
            parse_error = str(e)
        return self._send_json(200, {
            "name":         name,
            "path":         path,
            "service":      info.get("service"),
            "description":  info["description"],
            "format":       info.get("format"),
            "content":      content,
            # Send mtime as integer seconds — matches what the dashboard
            # previously got from `stat -c '%Y'` over SSH, so the
            # client-side mtime_seen check stays compatible.
            "mtime":        int(st.st_mtime),
            "size":         st.st_size,
            "parse_error":  parse_error,
        })

    # ---- /config/<name> PUT ----
    def _handle_write_config(self, name):
        """Validate a PUT body, then hand the write to _check_and_write.

        Responds 404 for an unknown config. Responds 400 for a missing or
        oversized body, a body that isn't a JSON object, non-string or
        non-JSON `content`, or a non-finite `mtime_seen`. Responds 408 if
        reading the body fails. Otherwise it sends whatever
        _check_and_write returns (200 / 409 / 500).
        """
        info = _config_entry(name)
        if not info:
            return self._send_json(404, {"error": "unknown config '%s'" % name})
        path = info["path"]

        # Read body
        try:
            length = int(self.headers.get("Content-Length") or 0)
        except ValueError:
            length = 0
        if length <= 0 or length > MAX_BODY_BYTES:
            return self._send_json(400, {"error": "missing or oversized body"})
        try:
            raw = self.rfile.read(length)
        except OSError as e:
            # Socket timeout (see `timeout`) or connection reset mid-body.
            return self._send_json(408, {"error": "failed reading request body: %s" % e})
        try:
            payload = json.loads(raw.decode("utf-8"))
        except (ValueError, RecursionError) as e:
            # ValueError covers UnicodeDecodeError and JSONDecodeError;
            # RecursionError can come from extremely deep nesting.
            return self._send_json(400, {"error": "bad request body: %s" % e})
        if not isinstance(payload, dict):
            return self._send_json(400, {"error": "request body must be a JSON object"})

        content = payload.get("content")
        mtime_seen = payload.get("mtime_seen")
        if not isinstance(content, str):
            return self._send_json(400, {"error": "content must be a string"})

        # Validate the body parses as JSON (after comment-stripping). We
        # don't write a file that voxl-vtx (or whatever consumer) will
        # refuse to load on startup.
        try:
            json.loads(_strip_c_comments_for_json(content))
        except (ValueError, RecursionError) as e:
            return self._send_json(400, {"error": "content is not valid JSON: %s" % e})

        # Concurrent-edit detection. mtime_seen == 0 or None (or any
        # non-number) means "force": the caller explicitly wants to
        # overwrite whatever's there.
        seen = None
        if isinstance(mtime_seen, (int, float)) and mtime_seen > 0:
            try:
                seen = int(mtime_seen)
            except (OverflowError, ValueError):
                # json.loads accepts Infinity, which int() cannot convert.
                return self._send_json(400, {"error": "mtime_seen must be a finite number"})

        # The response is built under the lock but sent after releasing
        # it, so a slow client can't stall other writers.
        code, body = self._check_and_write(path, content, seen)
        return self._send_json(code, body)

    def _check_and_write(self, path, content, seen):
        """Run the conflict check and atomic write; return (status, payload).

        `seen` is the caller's integer mtime, or None to skip the check. Any
        difference from the current on-disk mtime is a conflict, older or
        newer. For example, a file restored with an older timestamp still
        differs from what the caller loaded. Takes _write_lock for the whole
        check-then-write sequence.
        """
        backup_path = path + ".bak.dashboard"
        with self._write_lock:
            if seen is not None:
                try:
                    current_mtime = int(os.stat(path).st_mtime)
                except FileNotFoundError:
                    # File got deleted between read and write. Treat as a
                    # conflict — operator should reload and re-decide.
                    return 409, {"error": "file no longer exists on disk"}
                if current_mtime != seen:
                    return 409, {
                        "error":          "file was modified on disk since you loaded it",
                        "current_mtime":  current_mtime,
                        "your_mtime":     seen,
                    }
            try:
                self._atomic_write(path, backup_path, content)
            except OSError as e:
                return 500, {"error": "write failed: %s" % e}
            try:
                new_mtime = int(os.stat(path).st_mtime)
            except OSError:
                new_mtime = 0
        return 200, {
            "ok":          True,
            "new_mtime":   new_mtime,
            "backup_path": backup_path,
        }

    @staticmethod
    def _atomic_write(path, backup_path, content):
        """Atomically replace `path` with `content` (UTF-8), keeping a backup.

        If `path` is an existing regular file, it is first copied to
        `backup_path`, and its permission bits and owner are applied to the
        new file. The new content goes to a temp file in the same
        directory, which is fsync'd and then os.replace()d over `path`.
        Readers therefore see the old file or the new one, never a partial
        write. Raises OSError on failure; an unconsumed temp file is
        removed.
        """
        target_dir = os.path.dirname(path) or "."
        # Stat the original itself: shutil.copy2 copies mode and times but
        # not ownership, so the backup (created by this root process) is
        # not a reliable source for uid/gid.
        orig_st = os.stat(path) if os.path.isfile(path) else None
        if orig_st is not None:
            shutil.copy2(path, backup_path)

        fd, tmp_path = tempfile.mkstemp(prefix=".voxl-wg-tmp-", dir=target_dir)
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                f.write(content)
                f.flush()
                # Data must be on disk before the rename is, otherwise a
                # power cut can leave an empty file under the real name.
                os.fsync(f.fileno())
            if orig_st is not None:
                # Best-effort: on failure the file keeps mkstemp's 0600
                # mode / our ownership, which isn't worth failing over.
                try:
                    os.chmod(tmp_path, orig_st.st_mode & 0o777)
                    os.chown(tmp_path, orig_st.st_uid, orig_st.st_gid)
                except OSError:
                    pass
            os.replace(tmp_path, path)
            tmp_path = None  # consumed by replace
        finally:
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass

        # Best-effort: persist the directory entry created by the rename.
        try:
            dir_fd = os.open(target_dir, os.O_RDONLY)
        except OSError:
            return
        try:
            os.fsync(dir_fd)
        except OSError:
            pass
        finally:
            os.close(dir_fd)

    # ---- /restart-service/<name> POST ----
    def _handle_restart_service(self, name):
        """Restart a systemd unit, validated against the registry's
        service whitelist. Operators can only restart units that this
        service knows about — not arbitrary units.

        Responds 200 with {"ok": true, "active": <systemctl is-active>},
        or 502 if systemctl fails, times out, or cannot be run.
        """
        if not _SERVICE_NAME_RE.fullmatch(name):
            return self._send_json(400, {"error": "invalid service name"})
        legal = {info["service"] for info in DRONE_CONFIGS.values() if info.get("service")}
        if name not in legal:
            return self._send_json(400, {"error": "service '%s' is not in the restart whitelist" % name})
        # stdout/stderr=PIPE instead of capture_output=True: capture_output
        # only exists from Python 3.7, and this service must run on 3.6.
        try:
            subprocess.run(["systemctl", "restart", name],
                           stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                           check=True, timeout=15)
        except subprocess.CalledProcessError as e:
            return self._send_json(502, {
                "error":  "restart failed",
                "stderr": (e.stderr or b"").decode("utf-8", errors="replace").strip(),
            })
        except subprocess.TimeoutExpired:
            return self._send_json(502, {"error": "systemctl timed out"})
        except OSError as e:
            # e.g. systemctl binary missing.
            return self._send_json(502, {"error": "restart failed", "stderr": str(e)})
        # Report whether the unit is now active. is-active prints
        # "active" / "inactive" / "failed" / etc. and uses exit code 0
        # only for active.
        try:
            ret = subprocess.run(["systemctl", "is-active", name],
                                 stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                                 timeout=5)
            active = (ret.stdout or b"").decode("utf-8", errors="replace").strip()
        except (OSError, subprocess.SubprocessError):
            active = "unknown"
        return self._send_json(200, {"ok": True, "active": active})


class _ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
    """HTTPServer that handles each connection on its own daemon thread.

    daemon_threads lets the process exit without waiting for in-flight
    handlers. allow_reuse_address sets SO_REUSEADDR so a quick restart can
    rebind the port.
    """
    daemon_threads = True
    allow_reuse_address = True

    # Exceptions that only mean "the client went away or stalled". Before
    # Python 3.10, socket.timeout is not an alias of TimeoutError, so it is
    # listed separately (this file targets Python 3.6+).
    _QUIET_ERRORS = (ConnectionResetError, ConnectionAbortedError,
                     BrokenPipeError, TimeoutError, socket.timeout)

    def handle_error(self, request, client_address):
        # Suppress the stdlib stack-trace dump for routine client
        # disconnects. See the same logic in voxl-wg-server-web.
        # issubclass (not exact-class membership) so subclasses are
        # covered too.
        exc = sys.exc_info()[0]
        if exc is not None and issubclass(exc, self._QUIET_ERRORS):
            return
        super().handle_error(request, client_address)


# ---- entry point ----------------------------------------------------------

def main():
    """Run the config API until interrupted; return a process exit status.

    Exit statuses: 0 after a clean Ctrl-C / SIGINT shutdown, 1 when not
    running as root, 2 when wg0 never got an IPv4 address within
    WG0_WAIT_TIMEOUT_S, 3 when the listening socket could not be bound.
    """
    # Root is required to read/write the files under /etc/modalai.
    if os.geteuid():
        log("ERROR: must run as root (needs to read/write /etc/modalai)")
        return 1

    log("starting; waiting for wg0 interface...")
    bind_ip = wait_for_wg0()
    if not bind_ip:
        log("ERROR: wg0 did not come up within %ds; exiting (systemd will retry)"
            % WG0_WAIT_TIMEOUT_S)
        return 2

    address = (bind_ip, LISTEN_PORT)
    log("wg0 IP: %s; binding to %s:%d" % ((bind_ip,) + address))
    try:
        server = _ThreadedHTTPServer(address, ConfigAPIHandler)
    except OSError as bind_err:
        log("ERROR: bind failed: %s" % bind_err)
        return 3

    log("listening; %d config(s) registered" % len(DRONE_CONFIGS))
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        log("shutting down (SIGINT)")
    finally:
        # Best-effort socket cleanup; nothing useful to do if it fails.
        try:
            server.server_close()
        except Exception:
            pass
    return 0


if __name__ == "__main__":
    # main() returns the exit status; SystemExit passes it to the interpreter.
    raise SystemExit(main())
