Files
proxmox-task/proxmox_task_runner.py

215 lines
8.3 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
"""
Proxmox Task Runner
Monitors a directory for command files, starts a VM via Proxmox API,
executes a task inside the VM via SSH, reports progress, and shuts down the VM.
"""
import os
import time
import json
import logging
import signal
import sys
from pathlib import Path
from typing import Dict, Any
import requests
import paramiko
# ---------------------------- Configuration ----------------------------
# These should be set via environment variables or a .env file
# (NOTE(review): no .env loader is imported in this file, so the variables
# presumably have to be exported by the shell / service manager).
PROXMOX_HOST = os.getenv("PROXMOX_HOST", "https://proxmox.example.com:8006")
# Proxmox API token credentials; the fallbacks are placeholders, not usable values.
PROXMOX_TOKEN_ID = os.getenv("PROXMOX_TOKEN_ID", "automation@pam!token")
PROXMOX_TOKEN_SECRET = os.getenv("PROXMOX_TOKEN_SECRET", "your-token-secret")
# TLS certificate verification is OFF unless PROXMOX_VERIFY_SSL is "true"
# (case-insensitive).
VERIFY_SSL = os.getenv("PROXMOX_VERIFY_SSL", "false").lower() == "true"
# SSH settings for VM access (assumes VM has SSH server and key auth)
SSH_USERNAME = os.getenv("VM_SSH_USER", "root")
SSH_KEY_PATH = os.getenv("VM_SSH_KEY", "/root/.ssh/id_rsa")
SSH_PORT = int(os.getenv("VM_SSH_PORT", "22"))
# Directory to watch for command files (each file is a JSON task)
COMMAND_DIR = Path(os.getenv("COMMAND_DIR", "/var/task-commands"))
PROCESSED_DIR = COMMAND_DIR / "processed"
FAILED_DIR = COMMAND_DIR / "failed"
# Polling interval (seconds)
POLL_INTERVAL = int(os.getenv("POLL_INTERVAL", "5"))
# Log file path; overridable so the runner can work without write access to /var/log.
LOG_FILE = os.getenv("LOG_FILE", "/var/log/proxmox_task_runner.log")

# Logging setup: always log to stdout, and additionally to LOG_FILE when it
# can be opened. FileHandler opens the file immediately, so an unwritable path
# would otherwise raise at import time and abort the whole program.
_log_handlers = [logging.StreamHandler(sys.stdout)]
_log_file_error = None
try:
    _log_handlers.append(logging.FileHandler(LOG_FILE))
except OSError as _exc:
    _log_file_error = _exc
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=_log_handlers,
)
logger = logging.getLogger(__name__)
if _log_file_error is not None:
    logger.warning(f"Cannot open log file {LOG_FILE} ({_log_file_error}); logging to stdout only")
# ---------------------------- Helper Functions ----------------------------
def get_proxmox_auth_headers() -> Dict[str, str]:
    """Build the HTTP headers used for every Proxmox API request.

    The Authorization header uses Proxmox's API-token scheme,
    ``PVEAPIToken=<token id>=<secret>``, assembled from PROXMOX_TOKEN_ID and
    PROXMOX_TOKEN_SECRET. A JSON Content-Type header is always included.
    """
    credentials = f"{PROXMOX_TOKEN_ID}={PROXMOX_TOKEN_SECRET}"
    headers = {"Authorization": "PVEAPIToken=" + credentials}
    headers["Content-Type"] = "application/json"
    return headers
def start_vm(node: str, vmid: int) -> bool:
    """Ask the Proxmox API to power on a QEMU VM.

    POSTs to ``{PROXMOX_HOST}/api2/json/nodes/{node}/qemu/{vmid}/status/start``.

    Returns:
        True when the HTTP request succeeded (no error status). Only the HTTP
        status is checked, so callers poll for "running" separately
        (process_command_file uses wait_for_vm_status). False on any error,
        which is logged rather than raised.
    """
    vm_api = f"{PROXMOX_HOST}/api2/json/nodes/{node}/qemu/{vmid}"
    try:
        reply = requests.post(
            f"{vm_api}/status/start",
            headers=get_proxmox_auth_headers(),
            verify=VERIFY_SSL,
            timeout=10,
        )
        reply.raise_for_status()
        logger.info(f"Started VM {vmid} on node {node}")
    except Exception as err:
        logger.error(f"Failed to start VM {vmid}: {err}")
        return False
    return True
def stop_vm(node: str, vmid: int) -> bool:
    """Send a shutdown request for a QEMU VM through the Proxmox API.

    POSTs to ``{PROXMOX_HOST}/api2/json/nodes/{node}/qemu/{vmid}/status/shutdown``
    (the ``shutdown`` endpoint, not ``stop``).

    Returns:
        True when the HTTP request succeeded. That only means the request was
        accepted, not that the VM is off; callers poll for "stopped" afterwards.
        False on any error, which is logged rather than raised.

    NOTE(review): "shutdown" is presumably a graceful guest shutdown; a guest
    that ignores it may never reach "stopped" — confirm whether a forced stop
    is needed as a fallback.
    """
    vm_api = f"{PROXMOX_HOST}/api2/json/nodes/{node}/qemu/{vmid}"
    try:
        reply = requests.post(
            f"{vm_api}/status/shutdown",
            headers=get_proxmox_auth_headers(),
            verify=VERIFY_SSL,
            timeout=10,
        )
        reply.raise_for_status()
        logger.info(f"Shutdown request sent for VM {vmid} on node {node}")
    except Exception as err:
        logger.error(f"Failed to stop VM {vmid}: {err}")
        return False
    return True
def wait_for_vm_status(
    node: str,
    vmid: int,
    target_status: str,
    timeout: int = 120,
    poll_interval: float = 5,
) -> bool:
    """Poll a VM's current status until it equals *target_status* or time runs out.

    Args:
        node: Proxmox node hosting the VM.
        vmid: numeric VM id.
        target_status: value expected in the response's ``data.status`` field
            (this file waits for "running" and "stopped").
        timeout: overall deadline in seconds.
        poll_interval: seconds to wait between polls, including after a failed
            request.

    Returns:
        True as soon as the reported status matches, False once *timeout*
        seconds have elapsed (logged as an error). Request errors are logged
        as warnings and retried; they are never raised.
    """
    url = f"{PROXMOX_HOST}/api2/json/nodes/{node}/qemu/{vmid}/status/current"
    # Monotonic clock: wall-clock jumps (NTP, manual changes) must not
    # shorten or stretch the wait as they could with time.time().
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            resp = requests.get(url, headers=get_proxmox_auth_headers(), verify=VERIFY_SSL, timeout=5)
            resp.raise_for_status()
            # The API wraps the payload in {"data": ...}; tolerate a null "data".
            status = (resp.json().get("data") or {}).get("status")
        except Exception as e:
            logger.warning(f"Error checking VM status: {e}")
        else:
            if status == target_status:
                logger.info(f"VM {vmid} status is {target_status}")
                return True
        # Never sleep past the deadline.
        time.sleep(max(0.0, min(poll_interval, deadline - time.monotonic())))
    logger.error(f"Timeout waiting for VM {vmid} to reach {target_status}")
    return False
def run_task_via_ssh(vm_ip: str, task_command: str) -> Dict[str, Any]:
    """Run *task_command* inside the VM over SSH, logging its output as it arrives.

    Connects to ``vm_ip:SSH_PORT`` as SSH_USERNAME using the key at
    SSH_KEY_PATH and runs the command with a pseudo-terminal. Each non-empty
    output line is logged as ``[VM] <line>`` while the command runs.

    Args:
        vm_ip: address of the VM's SSH server.
        task_command: command line executed on the VM.

    Returns:
        On completion, a dict with:
          * ``exit_status``: the remote exit code.
          * ``stdout``: the complete captured output.
          * ``stderr``: captured stderr. With a PTY the remote side normally
            merges stderr into stdout, so this is usually empty.
          * ``progress``: the list of logged lines.
        On connection/SSH failure: ``{"exit_status": 1, "error": <message>}``.

    NOTE: there is no overall timeout; a remote command that never exits
    blocks this call.
    """
    import codecs  # local: only this function needs incremental decoding

    ssh = paramiko.SSHClient()
    # NOTE(security): AutoAddPolicy trusts any unknown host key, so the
    # connection is not protected against man-in-the-middle attacks. Consider
    # loading known_hosts and using RejectPolicy in production.
    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        ssh.connect(hostname=vm_ip, port=SSH_PORT, username=SSH_USERNAME, key_filename=SSH_KEY_PATH, timeout=10)
        logger.info(f"SSH connected to {vm_ip}")
        _stdin, stdout, stderr = ssh.exec_command(task_command, get_pty=True)
        channel = stdout.channel

        out_chunks = []  # raw stdout bytes, so the "stdout" field is complete
        err_chunks = []  # raw stderr bytes
        progress_lines = []
        # Incremental decoder: a multi-byte UTF-8 character split across two
        # recv() calls is decoded correctly instead of being dropped.
        decoder = codecs.getincrementaldecoder("utf-8")(errors="ignore")
        pending = ""  # text after the last line break, not logged yet

        def log_progress(text, final=False):
            nonlocal pending
            # Treat "\r" as a break too (PTY "\r\n" endings, progress bars);
            # the resulting empty pieces are skipped below.
            pieces = (pending + text).replace("\r", "\n").split("\n")
            pending = "" if final else pieces.pop()
            for piece in pieces:
                line = piece.rstrip()
                if line:
                    logger.info(f"[VM] {line}")
                    progress_lines.append(line)

        def consume_stdout(data, final=False):
            out_chunks.append(data)
            log_progress(decoder.decode(data, final), final)

        while not channel.exit_status_ready():
            got_data = False
            if channel.recv_ready():
                consume_stdout(channel.recv(32768))
                got_data = True
            if channel.recv_stderr_ready():
                # Drain stderr as well so it cannot fill the SSH window and
                # stall the remote command.
                err_chunks.append(channel.recv_stderr(32768))
                got_data = True
            if not got_data:
                # Sleep only when idle: avoids busy-waiting without
                # throttling commands that produce a lot of output.
                time.sleep(0.1)

        exit_status = channel.recv_exit_status()
        # Collect whatever is still buffered after exit. Bytes already taken
        # with channel.recv()/recv_stderr() are not returned again by read().
        consume_stdout(stdout.read(), final=True)
        err_chunks.append(stderr.read())

        result = {
            "exit_status": exit_status,
            "stdout": b"".join(out_chunks).decode("utf-8", errors="ignore"),
            "stderr": b"".join(err_chunks).decode("utf-8", errors="ignore"),
            "progress": progress_lines,
        }
        if exit_status == 0:
            logger.info(f"Task completed successfully on {vm_ip}")
        else:
            logger.error(f"Task failed on {vm_ip} with exit code {exit_status}")
        return result
    except Exception as e:
        logger.error(f"SSH error: {e}")
        return {"exit_status": 1, "error": str(e)}
    finally:
        ssh.close()
def _load_command(filepath: Path) -> tuple:
    """Parse a command file and return ``(vmid, node, vm_ip, task)``.

    Raises:
        OSError: the file cannot be read.
        ValueError: invalid JSON (json.JSONDecodeError is a ValueError), a
            top-level value that is not an object, a missing/empty required
            field, or a non-integer vm_id.
    """
    with open(filepath, "r", encoding="utf-8") as f:
        command_data = json.load(f)
    if not isinstance(command_data, dict):
        raise ValueError("Command file must contain a JSON object")
    # Expected keys: vm_id, node, vm_ip, task. vm_ip must be provided by the
    # writer of the file; it is not resolved through the Proxmox API.
    required = ("vm_id", "node", "vm_ip", "task")
    missing = [key for key in required if not command_data.get(key)]
    if missing:
        raise ValueError(f"Missing required fields in command file: {', '.join(missing)}")
    return (
        int(command_data["vm_id"]),
        command_data["node"],
        command_data["vm_ip"],
        command_data["task"],
    )


def _wait_for_ssh(host: str, port: int, timeout: float = 120, interval: float = 5) -> bool:
    """Return True once a TCP connection to host:port succeeds, False after *timeout* seconds.

    Proxmox reports "running" once the VM itself is started, which presumably
    happens well before the guest OS has sshd listening. Probing the port
    first avoids failing the task with "connection refused" while the guest
    boots.
    """
    import socket  # local: only needed for this reachability probe

    deadline = time.monotonic() + timeout
    while True:
        try:
            with socket.create_connection((host, port), timeout=5):
                return True
        except OSError:
            if time.monotonic() >= deadline:
                return False
            time.sleep(interval)


def _run_task_on_vm(node: str, vmid: int, vm_ip: str, task: str) -> bool:
    """Start the VM, run *task* over SSH, and shut the VM down again.

    Returns:
        True when the task exited with status 0.

    Raises:
        RuntimeError: the VM could not be started, did not reach "running",
            or its SSH port never became reachable.
    """
    if not start_vm(node, vmid):
        # The start request failed, so this run did not start the VM; leave it alone.
        raise RuntimeError("Failed to start VM")
    try:
        if not wait_for_vm_status(node, vmid, "running"):
            raise RuntimeError("VM did not start in time")
        if not _wait_for_ssh(vm_ip, SSH_PORT):
            raise RuntimeError("SSH did not become available in time")
        result = run_task_via_ssh(vm_ip, task)
        return result.get("exit_status") == 0
    finally:
        # Shut down on every path after a successful start request. This
        # includes start timeouts and the SystemExit raised by main's signal
        # handler, so a VM is never left running by this runner.
        stop_vm(node, vmid)
        wait_for_vm_status(node, vmid, "stopped")


def _move_command_file(filepath: Path, dest_dir: Path) -> Path:
    """Move *filepath* into *dest_dir* (created if needed) and return the new path.

    Path.replace overwrites an existing file of the same name on every
    platform, whereas Path.rename raises on Windows in that case.
    """
    dest_dir.mkdir(parents=True, exist_ok=True)
    dest_path = dest_dir / filepath.name
    filepath.replace(dest_path)
    return dest_path


def process_command_file(filepath: Path):
    """Execute the task described by one command file, then file it away.

    The file is moved to PROCESSED_DIR when the task exits with status 0.
    Otherwise it goes to FAILED_DIR: task failures, invalid files, and
    VM start/SSH errors. Exceptions are logged rather than propagated.
    """
    try:
        vmid, node, vm_ip, task = _load_command(filepath)
        logger.info(f"Processing command for VM {vmid}: {task}")
        success = _run_task_on_vm(node, vmid, vm_ip, task)
        if success:
            logger.info(f"Task succeeded for VM {vmid}")
            dest_dir = PROCESSED_DIR
        else:
            logger.error(f"Task failed for VM {vmid}")
            dest_dir = FAILED_DIR
        dest_path = _move_command_file(filepath, dest_dir)
        logger.info(f"Moved command file to {dest_path}")
    except Exception as e:
        logger.exception(f"Error processing {filepath}: {e}")
        # Move to failed directory for inspection (best-effort).
        try:
            _move_command_file(filepath, FAILED_DIR)
        except OSError as move_err:
            # A file that cannot be moved stays in COMMAND_DIR and will be
            # picked up again on the next poll, so make the failure visible.
            logger.error(f"Could not move {filepath} to {FAILED_DIR}: {move_err}")
def main():
    """Run the watch loop: poll COMMAND_DIR and process each ``*.json`` file in it.

    Setup:
      * Creates COMMAND_DIR and its processed/failed subdirectories.
      * Installs SIGINT/SIGTERM handlers that exit the process.

    Then, every POLL_INTERVAL seconds, it handles the command files currently
    present, one at a time, in file-name order. Errors in the loop are logged
    and the loop continues.
    """
    for directory in (COMMAND_DIR, PROCESSED_DIR, FAILED_DIR):
        directory.mkdir(parents=True, exist_ok=True)
    logger.info("Starting Proxmox Task Runner")
    logger.info(f"Watching directory: {COMMAND_DIR}")

    def signal_handler(sig, frame):
        # sys.exit raises SystemExit in the main thread at whatever point it
        # is executing, which may be in the middle of processing a task.
        logger.info("Shutting down...")
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)
    while True:
        try:
            # Regular *.json files directly in COMMAND_DIR only; the
            # processed/failed subdirectories are excluded by is_file().
            # iterdir() yields entries in arbitrary order, so sort to get a
            # deterministic, name-ordered queue.
            # NOTE(review): writers should create files atomically (write to
            # a temporary name, then rename to *.json); otherwise a
            # half-written file may be read and moved to failed/.
            files = sorted(f for f in COMMAND_DIR.iterdir() if f.is_file() and f.suffix == ".json")
            for f in files:
                process_command_file(f)
            time.sleep(POLL_INTERVAL)
        except Exception as e:
            logger.exception(f"Unexpected error in main loop: {e}")
            time.sleep(POLL_INTERVAL)


if __name__ == "__main__":
    main()