#!/usr/bin/env python3
# -----------------------------------------------------------------------------
# voxl-wg-config-api — drone-side HTTP service for safe remote config edits.
#
# This service exists so the dashboard can read and write a small,
# whitelisted set of /etc/modalai/* config files (and restart their
# matching systemd units) without needing SSH access to the drone.
#
# Why a REST service instead of SSH:
#   - No password handling on the dashboard side.
#   - No SSH handshake overhead (one HTTP round-trip per op instead of
#     ~1-2s of SSH negotiation).
#   - Narrower attack surface: this service can only touch files in the
#     whitelist and only restart units in the whitelist. Shell access
#     gives you the whole drone.
#   - Easier to reason about: explicit JSON in/out, explicit status codes,
#     no shell quoting concerns.
#
# Security model:
#   - Bound to the wg0 IP only. Not exposed on the drone's LTE / WiFi /
#     ethernet interfaces. If you're on the WireGuard VPN, you can talk
#     to this; otherwise the port doesn't exist for you.
#   - No application-level auth in v1 — VPN reachability IS the auth.
#     Same security stance as the previous SSH-based approach: anyone on
#     the VPN had shell-as-root via the documented default password.
#     If we later need per-operator audit logs / revocation, swap in a
#     bearer token loaded from /etc/voxl-wireguard/config-api.token.
#
# Endpoints:
#   GET  /healthz                  → 200 ok
#   GET  /configs                  → {"configs": [...]} — what's editable
#   GET  /config/<name>            → {"content", "mtime", "size", "parse_error"}
#   PUT  /config/<name>            → write file atomically (with backup)
#                                    body: {"content": "...", "mtime_seen": N}
#                                    409 if mtime_seen mismatches current
#   POST /restart-service/<name>   → systemctl restart on the whitelist
#
# Compatibility: pure stdlib Python 3.6+ so it runs on the qrb5165
# (Ubuntu 18.04) drones as well as newer ones. No pip deps, no venv.
# -----------------------------------------------------------------------------

import json
import os
import re
import shutil
import socket
import subprocess
import sys
import tempfile
import threading
import time
from http.server import HTTPServer, BaseHTTPRequestHandler
from socketserver import ThreadingMixIn

# ---- Tunables --------------------------------------------------------------

# Port the service listens on. Chosen in the unassigned IANA user range,
# not used by anything we've seen on VOXL images. Change here if you
# need to and update the dashboard's _DRONE_API_PORT to match.
LISTEN_PORT = 7900

# Startup wait for wg0. WG0_WAIT_INTERVAL_S is the sleep between probes of
# the interface; WG0_WAIT_TIMEOUT_S is the default overall budget before
# wait_for_wg0() gives up and main() exits. systemd will restart us if we
# exit; the sleep keeps us from spinning hot while the tunnel comes up.
WG0_WAIT_INTERVAL_S = 2.0
WG0_WAIT_TIMEOUT_S = 60.0

# Hard ceiling on PUT body size, checked against the request's
# Content-Length before the body is read. Note this bounds the whole JSON
# request body, not just the config text inside it. Well above any real
# config file (voxl-vtx.conf is ~5KB), so 1MB leaves ample headroom.
MAX_BODY_BYTES = 1_000_000

# Registry of files editable through this service. To add a new config:
# add an entry here, ensure the file actually exists at that path on
# the target drone images, and confirm the matching systemd unit name.
# Files that don't exist on a given drone are skipped from GET /configs.
#
# Keys of each entry, as used by the handlers in this file:
#   "path"        (required) absolute path read by GET /config/<name> and
#                 replaced by PUT /config/<name>.
#   "service"     (optional) systemd unit name. The set of all "service"
#                 values is the whitelist for POST /restart-service/<name>.
#   "description" (required) human-readable label returned to clients.
#   "format"      (optional) returned to clients verbatim. The handlers do
#                 not branch on it: content is always validated as JSON
#                 after C-style comment stripping.
DRONE_CONFIGS = {
    "voxl-vtx": {
        "path":        "/etc/modalai/voxl-vtx.conf",
        "service":     "voxl-vtx",
        "description": "VTX (wireless video transmitter)",
        "format":      "json-with-c-comments",
    },
    # Future entries go here. Examples:
    # "voxl-camera-server": {
    #     "path":        "/etc/modalai/voxl-camera-server.conf",
    #     "service":     "voxl-camera-server",
    #     "description": "Camera server",
    #     "format":      "json-with-c-comments",
    # },
}

# Allowed config-name regex (URL-path validation). Independent from
# DRONE_CONFIGS keys so we can reject malformed paths before lookup.
# `$` also matches just before a single trailing newline, so this check is
# only a pre-filter: the exact-string registry / whitelist lookups that
# follow it are what actually decide whether a name is accepted.
# _SERVICE_NAME_RE is the same compiled pattern object, under a second name
# so the restart handler reads naturally.
_CONFIG_NAME_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
_SERVICE_NAME_RE = _CONFIG_NAME_RE


def log(msg):
    """Write one prefixed line to stderr and flush it immediately.

    The "[voxl-wg-config-api]" prefix lets journalctl output be grepped
    for this service; the explicit flush keeps lines from sitting in a
    buffer when stderr is not a terminal.
    """
    stream = sys.stderr
    stream.write("".join(("[voxl-wg-config-api] ", msg, "\n")))
    stream.flush()


# ---- wg0 IP discovery ------------------------------------------------------
#
# Bind only to the WireGuard interface IP so the service is unreachable
# from LTE / WiFi / Ethernet. Reading via `ip addr show wg0` is the
# simplest portable approach — pyroute2 would be cleaner but is not on
# stock VOXL images.

def get_wg0_ip():
    """Return wg0's IPv4 address as a string, or None when unavailable.

    Runs `ip -4 addr show dev wg0` (3 s timeout) and extracts the first
    "inet A.B.C.D/len" entry. A non-zero exit, a timeout, or a missing
    `ip` binary all count as "not up yet" and yield None.
    """
    cmd = ["ip", "-4", "addr", "show", "dev", "wg0"]
    not_ready = (
        subprocess.CalledProcessError,
        subprocess.TimeoutExpired,
        FileNotFoundError,
    )
    try:
        raw = subprocess.check_output(cmd, stderr=subprocess.DEVNULL, timeout=3)
    except not_ready:
        return None

    listing = raw.decode("utf-8", errors="replace")
    # Look for "inet 10.8.0.X/Y" at the start of a (possibly indented) line.
    found = re.search(r"^\s*inet\s+(\d+\.\d+\.\d+\.\d+)/\d+", listing, re.MULTILINE)
    if found is None:
        return None
    return found.group(1)


def wait_for_wg0(timeout_s=WG0_WAIT_TIMEOUT_S):
    """Poll until wg0 has an IPv4 address and return it.

    Probes with get_wg0_ip(), sleeping WG0_WAIT_INTERVAL_S seconds between
    attempts. Returns None once `timeout_s` seconds (monotonic clock) have
    elapsed without an address; the deadline is checked before each probe.
    """
    give_up_at = time.monotonic() + timeout_s
    while True:
        if not (time.monotonic() < give_up_at):
            return None
        address = get_wg0_ip()
        if address:
            return address
        time.sleep(WG0_WAIT_INTERVAL_S)


# ---- Path helpers ----------------------------------------------------------

def _config_entry(name):
    """Return the DRONE_CONFIGS entry for `name`, or None.

    None is returned both for names that fail the URL-safe name pattern
    and for well-formed names that are not in the registry.
    """
    if _CONFIG_NAME_RE.match(name):
        return DRONE_CONFIGS.get(name)
    return None


def _strip_c_comments_for_json(text):
    """Strip /* ... */ and // comments so the body parses as JSON. Used
    only to validate that the operator's edit is well-formed; the comments
    are preserved verbatim on disk."""
    out = []
    i = 0
    n = len(text)
    in_string = False
    in_line_comment = False
    in_block_comment = False
    while i < n:
        c = text[i]
        nxt = text[i+1] if i+1 < n else ""
        if in_line_comment:
            if c == "\n":
                in_line_comment = False
                out.append(c)
            i += 1
            continue
        if in_block_comment:
            if c == "*" and nxt == "/":
                in_block_comment = False
                i += 2
            else:
                if c == "\n":
                    out.append(c)
                i += 1
            continue
        if in_string:
            out.append(c)
            if c == "\\" and i + 1 < n:
                out.append(nxt)
                i += 2
                continue
            if c == '"':
                in_string = False
            i += 1
            continue
        if c == '"':
            in_string = True
            out.append(c)
            i += 1
            continue
        if c == "/" and nxt == "/":
            in_line_comment = True
            i += 2
            continue
        if c == "/" and nxt == "*":
            in_block_comment = True
            i += 2
            continue
        out.append(c)
        i += 1
    return "".join(out)


# ---- HTTP handler ---------------------------------------------------------

class ConfigAPIHandler(BaseHTTPRequestHandler):
    """HTTP handler for the config read / write / restart-service API.

    Routes (query strings are ignored):
      GET  /healthz                 plain-text liveness probe
      GET  /configs                 registry entries whose file exists
      GET  /config/<name>           file content plus mtime/size metadata
      PUT  /config/<name>           validate and atomically replace the file
      POST /restart-service/<name>  systemctl restart a whitelisted unit

    Every response other than /healthz is a JSON object; failures carry an
    "error" key. Config and service names are checked against the
    module-level registry before any file or systemd operation happens.
    """

    server_version = "voxl-wg-config-api/1"
    # Tighter than default — we don't need the Python version exposed.
    sys_version = ""

    # Requests are served on separate threads, so two PUTs could both pass
    # the mtime check before either one renames its file into place. This
    # lock makes check-then-replace atomic with respect to other requests
    # handled by this process. It cannot guard against edits made to the
    # file by anything outside this service.
    _write_lock = threading.Lock()

    def log_message(self, fmt, *args):
        """Emit an access/error log line with the service prefix.

        Replaces the stdlib "host - - [date]" format with a tighter one so
        journalctl groups these lines with the rest of our output.
        """
        sys.stderr.write("[voxl-wg-config-api] %s %s\n" %
                         (self.address_string(), fmt % args))
        # Non-interactive stderr is block-buffered before Python 3.9;
        # without a flush, access logs would lag behind in the journal.
        sys.stderr.flush()

    # ---- responses ----
    def _send_body(self, code, body, ctype, no_store=False):
        """Send a complete response with an explicit Content-Length.

        A client that disconnects mid-response is ignored; there is nobody
        left to report the failure to.
        """
        try:
            self.send_response(code)
            self.send_header("Content-Type", ctype)
            self.send_header("Content-Length", str(len(body)))
            if no_store:
                self.send_header("Cache-Control", "no-store")
            self.end_headers()
            self.wfile.write(body)
        except (BrokenPipeError, ConnectionResetError, ConnectionAbortedError):
            pass

    def _send_json(self, code, payload):
        """Serialize `payload` as JSON and send it with status `code`."""
        body = json.dumps(payload).encode("utf-8")
        self._send_body(code, body, "application/json", no_store=True)

    def _send_text(self, code, text, ctype="text/plain"):
        """Send `text` (str or bytes) with status `code` and type `ctype`."""
        body = text.encode("utf-8") if isinstance(text, str) else text
        self._send_body(code, body, ctype)

    # ---- routing ----
    def _request_path(self):
        """Return the request target with any query string removed."""
        return self.path.split("?", 1)[0]

    def do_GET(self):
        """Dispatch GET: /healthz, /configs, /config/<name>; else 404."""
        path = self._request_path()
        if path == "/healthz":
            return self._send_text(200, "ok\n")
        if path == "/configs":
            return self._handle_list_configs()
        if path.startswith("/config/"):
            return self._handle_read_config(path[len("/config/"):])
        return self._send_json(404, {"error": "not found"})

    def do_PUT(self):
        """Dispatch PUT: /config/<name>; anything else is 405."""
        path = self._request_path()
        if path.startswith("/config/"):
            return self._handle_write_config(path[len("/config/"):])
        return self._send_json(405, {"error": "PUT not allowed here"})

    def do_POST(self):
        """Dispatch POST: /restart-service/<name>; anything else is 405."""
        path = self._request_path()
        if path.startswith("/restart-service/"):
            return self._handle_restart_service(path[len("/restart-service/"):])
        return self._send_json(405, {"error": "POST not allowed here"})

    # ---- /configs ----
    def _handle_list_configs(self):
        """Return the registry filtered to entries whose file actually
        exists on this drone. Lets a single registry serve drones with
        different installed packages (no voxl-vtx → no voxl-vtx entry)."""
        configs = [
            {
                "name":        name,
                "path":        info["path"],
                "service":     info.get("service"),
                "description": info["description"],
                "format":      info.get("format"),
            }
            for name, info in DRONE_CONFIGS.items()
            if os.path.isfile(info["path"])
        ]
        return self._send_json(200, {"configs": configs})

    # ---- /config/<name> GET ----
    def _handle_read_config(self, name):
        """Return a config file's content and metadata.

        Responses: 200 with content/mtime/size/parse_error, 404 for an
        unknown name or a missing file, 500 when the file cannot be read.
        `parse_error` is None when the content is valid JSON after comment
        stripping, otherwise the parser's message so the dashboard can warn.
        """
        info = _config_entry(name)
        if not info:
            return self._send_json(404, {"error": "unknown config '%s'" % name})
        path = info["path"]
        try:
            with open(path, "r", encoding="utf-8", errors="replace") as f:
                # fstat the descriptor we read from, so the mtime handed to
                # the client always belongs to the same file version as the
                # content (writes replace the file via rename).
                st = os.fstat(f.fileno())
                content = f.read()
        except FileNotFoundError:
            return self._send_json(404, {"error": "file not present: %s" % path})
        except PermissionError as e:
            return self._send_json(500, {"error": "permission denied reading %s: %s" % (path, e)})
        except OSError as e:
            return self._send_json(500, {"error": "read failed: %s" % e})

        # Try to detect malformed-on-disk JSON so the dashboard can warn.
        parse_error = None
        try:
            json.loads(_strip_c_comments_for_json(content))
        except (ValueError, RecursionError) as e:
            parse_error = str(e)
        return self._send_json(200, {
            "name":         name,
            "path":         path,
            "service":      info.get("service"),
            "description":  info["description"],
            "format":       info.get("format"),
            "content":      content,
            # Send mtime as integer seconds — matches what the dashboard
            # previously got from `stat -c '%Y'` over SSH, so the
            # client-side mtime_seen check stays compatible.
            "mtime":        int(st.st_mtime),
            "size":         st.st_size,
            "parse_error":  parse_error,
        })

    # ---- /config/<name> PUT ----
    @staticmethod
    def _parse_write_payload(raw):
        """Decode a PUT body into ``(content, mtime_seen)``.

        `raw` is the request body as bytes. `content` is the new file text.
        `mtime_seen` is an int (seconds) when the client supplied a positive
        number, or None when the client asked for a forced overwrite (value
        absent, null, zero, negative, or not a number).

        Raises ValueError, with a message suitable for the 400 response,
        when the body is not a UTF-8 JSON object, `content` is not a string
        that can be stored as UTF-8, or `mtime_seen` is not finite.
        """
        try:
            payload = json.loads(raw.decode("utf-8"))
        except ValueError as e:  # covers JSONDecodeError and UnicodeDecodeError
            raise ValueError("bad request body: %s" % e) from None
        if not isinstance(payload, dict):
            raise ValueError("bad request body: expected a JSON object")

        content = payload.get("content")
        if not isinstance(content, str):
            raise ValueError("content must be a string")
        try:
            # A JSON escape such as "\ud800" decodes to a lone surrogate,
            # which cannot be written to a UTF-8 file.
            content.encode("utf-8")
        except UnicodeEncodeError:
            raise ValueError("content must be encodable as UTF-8") from None

        mtime_seen = payload.get("mtime_seen")
        if isinstance(mtime_seen, (int, float)) and mtime_seen > 0:
            try:
                return content, int(mtime_seen)
            except (OverflowError, ValueError):
                # json.loads accepts the non-standard literal Infinity.
                raise ValueError("mtime_seen must be a finite number") from None
        return content, None

    @staticmethod
    def _replace_file(path, content, backup_path):
        """Atomically replace `path` with `content`, backing it up first.

        If `path` is an existing regular file it is copied to `backup_path`
        and its mode and ownership are applied to the replacement on a
        best-effort basis. The new text is written to a temp file in the
        same directory (so the rename stays on one filesystem), synced to
        disk, then renamed over `path`. Raises OSError on failure; the temp
        file is removed if the rename did not happen.
        """
        target_dir = os.path.dirname(path) or "/"
        original = None
        if os.path.isfile(path):
            # Take mode/owner from the live file: copy2 does not carry
            # ownership over to the backup, and a backup left by an earlier
            # write may describe a different file version.
            original = os.stat(path)
            shutil.copy2(path, backup_path)

        fd, tmp_path = tempfile.mkstemp(prefix=".voxl-wg-tmp-", dir=target_dir)
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                f.write(content)
                # Get the bytes onto disk before the rename so a power cut
                # cannot leave an empty or truncated config in place.
                f.flush()
                os.fsync(f.fileno())
            if original is not None:
                try:
                    os.chmod(tmp_path, original.st_mode & 0o777)
                except OSError:
                    pass  # best effort
                try:
                    os.chown(tmp_path, original.st_uid, original.st_gid)
                except OSError:
                    pass  # not running as root or no permission
            os.replace(tmp_path, path)
            tmp_path = None  # consumed by replace
        finally:
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass

    @classmethod
    def _write_if_unchanged(cls, path, content, mtime_seen):
        """Run the conflict check and the write; return ``(status, body)``.

        With an int `mtime_seen`, the write is refused with 409 unless the
        file's current integer mtime equals it. Any difference counts,
        including an older on-disk mtime (clock stepped back, file restored
        from a backup). `mtime_seen=None` skips the check (forced write).
        Callers are expected to hold `_write_lock`.
        """
        backup_path = path + ".bak.dashboard"
        if mtime_seen is not None:
            try:
                current_mtime = int(os.stat(path).st_mtime)
            except FileNotFoundError:
                # File got deleted between read and write. Treat as a
                # conflict — operator should reload and re-decide.
                return 409, {"error": "file no longer exists on disk"}
            except OSError as e:
                return 500, {"error": "write failed: %s" % e}
            if current_mtime != mtime_seen:
                return 409, {
                    "error":          "file was modified on disk since you loaded it",
                    "current_mtime":  current_mtime,
                    "your_mtime":     mtime_seen,
                }

        try:
            cls._replace_file(path, content, backup_path)
        except OSError as e:
            return 500, {"error": "write failed: %s" % e}

        try:
            new_mtime = int(os.stat(path).st_mtime)
        except OSError:
            new_mtime = 0
        return 200, {
            "ok":          True,
            "new_mtime":   new_mtime,
            "backup_path": backup_path,
        }

    def _handle_write_config(self, name):
        """Validate a PUT body and atomically replace the config file.

        Body: {"content": "<file text>", "mtime_seen": <seconds or 0/null>}.
        Responses: 200 on success, 400 for a malformed body or content that
        is not JSON after comment stripping, 404 for an unknown name, 409
        on an mtime conflict, 500 when the write fails.
        """
        info = _config_entry(name)
        if not info:
            return self._send_json(404, {"error": "unknown config '%s'" % name})

        try:
            length = int(self.headers.get("Content-Length") or 0)
        except ValueError:
            length = 0
        if length <= 0 or length > MAX_BODY_BYTES:
            return self._send_json(400, {"error": "missing or oversized body"})

        try:
            content, mtime_seen = self._parse_write_payload(self.rfile.read(length))
        except ValueError as e:
            return self._send_json(400, {"error": str(e)})

        # Validate the body parses as JSON (after comment-stripping). We
        # don't write a file that voxl-vtx (or whatever consumer) will
        # refuse to load on startup.
        try:
            json.loads(_strip_c_comments_for_json(content))
        except (ValueError, RecursionError) as e:
            return self._send_json(400, {"error": "content is not valid JSON: %s" % e})

        # Hold the lock only for the filesystem work; the response is sent
        # afterwards so a slow client cannot stall other writers.
        with self._write_lock:
            status, body = self._write_if_unchanged(info["path"], content, mtime_seen)
        return self._send_json(status, body)

    # ---- /restart-service/<name> POST ----
    def _handle_restart_service(self, name):
        """Restart a systemd unit, validated against the registry's
        service whitelist. Operators can only restart units that this
        service knows about — not arbitrary units.

        Responses: 200 with the unit's `is-active` state, 400 for a name
        that is malformed or not whitelisted, 502 when systemctl fails,
        times out, or cannot be run.
        """
        if not _SERVICE_NAME_RE.match(name):
            return self._send_json(400, {"error": "invalid service name"})
        legal = {info["service"] for info in DRONE_CONFIGS.values() if info.get("service")}
        if name not in legal:
            return self._send_json(400, {"error": "service '%s' is not in the restart whitelist" % name})
        try:
            # stdout/stderr=PIPE rather than capture_output=True: the latter
            # needs Python 3.7 and this service must run on 3.6.
            subprocess.run(["systemctl", "restart", name], check=True,
                           stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                           timeout=15)
        except subprocess.CalledProcessError as e:
            return self._send_json(502, {
                "error":  "restart failed",
                "stderr": (e.stderr or b"").decode("utf-8", errors="replace").strip(),
            })
        except subprocess.TimeoutExpired:
            return self._send_json(502, {"error": "systemctl timed out"})
        except OSError as e:
            return self._send_json(502, {"error": "could not run systemctl: %s" % e})

        # Report whether the unit is now active. is-active prints
        # "active" / "inactive" / "failed" / etc. and uses exit code 0
        # only for active, so the exit status is deliberately not checked.
        try:
            ret = subprocess.run(["systemctl", "is-active", name],
                                 stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                                 timeout=5)
            active = (ret.stdout or b"").decode("utf-8", errors="replace").strip()
        except (subprocess.SubprocessError, OSError):
            active = "unknown"
        return self._send_json(200, {"ok": True, "active": active})


class _ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
    daemon_threads = True
    allow_reuse_address = True

    def server_bind(self):
        """Override to clamp the TCP maximum segment size BEFORE listen.
        The drone's wg0 has MTU 1280, so each TCP segment can hold at
        most 1240 bytes of payload (1280 minus the 40-byte TCP+IP
        header). The Linux kernel should figure this out from the route
        and advertise MSS=1240 in the SYN-ACK, then never send segments
        larger than that. But in practice we've seen the response body
        get dropped silently while the headers arrive fine — the
        signature of segments-too-large on a path where PMTU discovery
        isn't propagating through WireGuard + LTE NAT cleanly.

        Setting TCP_MAXSEG to 1200 forces an explicit clamp. It's a
        ceiling, so the kernel will use min(1200, route_PMTU) which
        keeps us safe even if the route MTU changes. 1200 leaves an
        80-byte margin under 1280 for any extra encapsulation overhead
        we don't account for.

        Must be set on the LISTENING socket before any accept(); accept
        inherits it on each connection. Set after super().server_bind()
        (which calls bind()) but before server_activate() (which calls
        listen()) — that ordering is what HTTPServer.__init__ does
        between server_bind and server_activate.
        """
        super().server_bind()
        try:
            self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_MAXSEG, 1200)
            # Confirm it took. Linux will silently lower the value if it's
            # too high for the route, but won't raise. Read it back.
            actual = self.socket.getsockopt(socket.IPPROTO_TCP, socket.TCP_MAXSEG)
            log("TCP_MAXSEG set to %d on listening socket" % actual)
        except (OSError, AttributeError) as e:
            # AttributeError on platforms missing TCP_MAXSEG in the socket
            # module (none we care about, but safe to fall back).
            log("warning: could not set TCP_MAXSEG: %s" % e)

    def handle_error(self, request, client_address):
        # Suppress the stdlib stack-trace dump for routine client
        # disconnects. See the same logic in voxl-wg-server-web.
        exc = sys.exc_info()[0]
        if exc in (ConnectionResetError, ConnectionAbortedError,
                   BrokenPipeError, TimeoutError):
            return
        super().handle_error(request, client_address)


# ---- entry point ----------------------------------------------------------

def main():
    """Start the service and block until it stops; return an exit code.

    Exit codes:
      0  server stopped after SIGINT
      1  not running as root
      2  wg0 never obtained an address within the wait timeout
      3  could not bind the listening socket
    """
    running_as_root = os.geteuid() == 0
    if not running_as_root:
        log("ERROR: must run as root (needs to read/write /etc/modalai)")
        return 1

    log("starting; waiting for wg0 interface...")
    wg_ip = wait_for_wg0()
    if not wg_ip:
        log("ERROR: wg0 did not come up within %ds; exiting (systemd will retry)" % WG0_WAIT_TIMEOUT_S)
        return 2
    log("wg0 IP: %s; binding to %s:%d" % (wg_ip, wg_ip, LISTEN_PORT))

    listen_addr = (wg_ip, LISTEN_PORT)
    try:
        server = _ThreadedHTTPServer(listen_addr, ConfigAPIHandler)
    except OSError as e:
        log("ERROR: bind failed: %s" % e)
        return 3

    log("listening; %d config(s) registered" % len(DRONE_CONFIGS))
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        log("shutting down (SIGINT)")
    finally:
        # Closing the socket is best effort; a failure here must not mask
        # whatever ended the serve loop.
        try:
            server.server_close()
        except Exception:
            pass
    return 0


if __name__ == "__main__":
    sys.exit(main())
