Files
2026-06-05 15:52:42 -07:00

483 lines
22 KiB
Python

"""Local-only advisory context bundle planner for Atlas/Hermes.
This module intentionally emits a retrieval/authority plan only. It does not call
Hermes memory/session/RAG/web tools, mutate vector stores, broaden private roots,
or change live routing.
"""
from __future__ import annotations

import ipaddress
import json
import math
import re
import time
import urllib.error
import urllib.parse
import urllib.request
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Mapping, Sequence
# Schema tag written into every plan by build_plan() and checked by validate_plan().
SCHEMA = "atlas_context_gate_plan_v1"

# Default file read by read_npu_busy_time_us().
NPU_BUSY_PATH = Path("/sys/class/accel/accel0/device/npu_busy_time_us")

# Default endpoint for classify_live(); a loopback address.
DEFAULT_CLASSIFIER_URL = "http://127.0.0.1:18819/v1/classify"

# Authority flags copied into every plan; validate_plan() requires an exact match,
# so every flag stays False.
AUTHORITY = dict.fromkeys(
    (
        "may_route",
        "may_write_memory",
        "may_send_external",
        "may_process_private_dirs",
        "may_execute_tools",
        "may_restart_services",
        "may_mutate_vector_db",
        "may_change_live_config",
    ),
    False,
)

# Gate map copied into every plan; each gate carries the same closed marker.
GATES = dict.fromkeys(
    (
        "live_routing_change",
        "memory_write",
        "outbound_send",
        "tool_execution",
        "service_restart",
        "vector_mutation",
        "private_root_broadening",
    ),
    "closed_requires_explicit_approval",
)

# Source names accepted by _source() and validate_plan().
_ALLOWED_SOURCES = {
    "durable_memory",
    "session_search",
    "rag_search",
    "repo_files",
    "live_system",
    "web",
    "no_retrieval",
}
class ContextGateError(ValueError):
    """Raised for invalid requests or unavailable required local stages.

    This is a ``ValueError`` subclass, so ``except ValueError`` handlers also
    catch it. The message is a short machine-readable code such as
    ``"query_required"``.
    """
@dataclass(frozen=True)
class ClassifierResult:
    """Immutable outcome of one classification pass plus its NPU counters."""

    # Label name -> label payload (either a bare value or a mapping with
    # "value"/"confidence" keys; see _label_value / _label_confidence).
    labels: Mapping[str, Any]
    # "npu_busy_delta_us" as reported in the classifier response, if usable.
    npu_busy_delta_us: int | None
    # "sysfs_npu_busy_delta_us" as reported in the classifier response, if usable.
    sysfs_npu_busy_delta_us: int | None
    # Counter difference measured by this module around the classifier call.
    outer_sysfs_delta_us: int | None
    # True when produced by classify_live, False when by classify_offline.
    live: bool
    # Optional note that is forwarded into the plan's warnings list.
    warning: str | None = None
def read_npu_busy_time_us(path: Path = NPU_BUSY_PATH) -> int | None:
    """Return the integer stored in *path*, or ``None`` when it cannot be read.

    Args:
        path: File expected to hold a single integer as UTF-8 text.

    Returns:
        The parsed integer, or ``None`` if the file is missing, unreadable,
        not valid UTF-8, or does not contain an integer.
    """
    try:
        raw_counter = path.read_text(encoding="utf-8")
        return int(raw_counter.strip())
    except (OSError, ValueError):
        # OSError already covers FileNotFoundError and PermissionError;
        # ValueError covers both a bad integer and a UTF-8 decode failure.
        return None
def _label_value(labels: Mapping[str, Any], name: str, default: Any) -> Any:
    """Look up label *name* and unwrap its ``"value"`` entry when it has one.

    Args:
        labels: Label name -> payload mapping.
        name: Label to read.
        default: Returned when *name* is absent from *labels*.

    Returns:
        ``payload["value"]`` when the payload is a mapping containing a
        ``"value"`` key; otherwise the payload itself (or *default*).
    """
    entry = labels.get(name, default)
    is_wrapped = isinstance(entry, Mapping) and "value" in entry
    return entry["value"] if is_wrapped else entry
def _label_confidence(labels: Mapping[str, Any], name: str, default: float = 0.5) -> float:
    """Return the ``"confidence"`` of label *name* as a finite float.

    Args:
        labels: Label name -> payload mapping.
        name: Label to read.
        default: Fallback used whenever no usable confidence is available.

    Returns:
        ``float(payload["confidence"])`` when the payload is a mapping and the
        value converts to a finite float; otherwise *default*. A mapping
        without a ``"confidence"`` key yields ``float(default)``.
    """
    entry = labels.get(name)
    if not isinstance(entry, Mapping):
        return default
    try:
        confidence = float(entry.get("confidence", default))
    except (TypeError, ValueError, OverflowError):
        # OverflowError: float() of an int too large for a double.
        return default
    # json.loads accepts NaN/Infinity by default; letting them through would
    # make max()/round() results meaningless and serialize as non-standard JSON.
    if not math.isfinite(confidence):
        return default
    return confidence
def heuristic_labels(query: str, context: Mapping[str, Any] | None = None) -> dict[str, Any]:
    """Derive classifier-style labels from keyword matches, with no model call.

    Small transparent fallback used by tests and explicit offline smoke mode.
    Matching is plain substring containment on the lower-cased query, so a
    keyword also matches when it appears inside a longer word.

    Args:
        query: Free-text request to label.
        context: Optional request context; only its ``platform`` entry is read.

    Returns:
        A dict with the keys ``tool_needed``, ``memory_candidate``, ``urgency``,
        ``workflow_category`` and ``safety_confirmation_required``. Each maps
        to ``{"value": ..., "confidence": ...}`` where the confidence is a
        fixed constant chosen by the outcome.
    """
    lowered = query.lower()
    platform = str((context or {}).get("platform", "unknown")).lower()

    def mentions(*keywords: str) -> bool:
        return any(keyword in lowered for keyword in keywords)

    live_state_words = ("current", "now", "health", "port", "process", "systemd", "status", "npu", "listening", "logs")
    prior_words = ("where did we leave", "what did we decide", "previous", "earlier", "handoff", "plan")
    coding_words = ("implement", "code", "repo", "test", "pytest", "diff", "branch", "hermes")
    research_words = ("research", "compare", "summarize", "explain", "what is", "how do i")
    unsafe_words = ("change live routing", "live routing", "restart", "send", "write memory", "reindex", "mutate", "delete")

    touches_live_state = mentions(*live_state_words)
    touches_code = mentions(*coding_words)
    needs_confirmation = mentions(*unsafe_words)
    needs_tool = touches_live_state or touches_code or needs_confirmation

    # First matching rule wins: coding > devops > research > chat.
    if platform == "kanban" or "kanban" in lowered or touches_code:
        category = "coding"
    elif touches_live_state:
        category = "devops"
    elif mentions(*research_words, *prior_words):
        category = "research"
    else:
        category = "chat"

    if mentions("remember", "preference"):
        memory_candidate = "durable_user_fact"
    elif mentions("convention", "workflow"):
        memory_candidate = "workflow_convention"
    else:
        memory_candidate = "none"

    urgency = "high" if mentions("urgent", "critical", "down", "broken") else "normal"

    return {
        "tool_needed": {
            "value": needs_tool,
            "confidence": 0.76 if needs_tool else 0.68,
        },
        "memory_candidate": {
            "value": memory_candidate,
            "confidence": 0.35 if memory_candidate == "none" else 0.8,
        },
        "urgency": {
            "value": urgency,
            "confidence": 0.8 if urgency == "high" else 0.65,
        },
        "workflow_category": {
            "value": category,
            "confidence": 0.7 if category == "chat" else 0.78,
        },
        "safety_confirmation_required": {
            "value": needs_confirmation,
            "confidence": 0.9 if needs_confirmation else 0.2,
        },
    }
class _NoClassifierRedirectHandler(urllib.request.HTTPRedirectHandler):
    """Redirect handler that never follows a redirect (fail closed).

    The classifier URL is validated as local before the request is sent; a
    redirect could point somewhere else, so no follow-up request is produced.
    """

    def redirect_request(self, req, fp, code, msg, headers, newurl):  # type: ignore[no-untyped-def]
        """Decline every redirect by returning ``None`` instead of a new request."""
        return None
# Shared opener used by classify_live(); it is built with the no-redirect
# handler so a redirect response is not followed to another URL.
_CLASSIFIER_OPENER = urllib.request.build_opener(_NoClassifierRedirectHandler)
def classify_live(
    query: str,
    context: Mapping[str, Any] | None = None,
    classifier_url: str = DEFAULT_CLASSIFIER_URL,
    timeout: float = 8.0,
) -> ClassifierResult:
    """POST *query* to the local classifier and return its labels and NPU deltas.

    The NPU busy counter is read immediately before and after the HTTP call so
    the caller gets a delta measured outside the classifier's own report.

    Args:
        query: Text to classify.
        context: Optional request context; only ``platform`` is forwarded
            (defaulting to ``"cli"``).
        classifier_url: Endpoint to call; must pass ``validate_classifier_url``.
        timeout: Timeout in seconds passed to the opener.

    Returns:
        A ``ClassifierResult`` with ``live=True``.

    Raises:
        ContextGateError: ``invalid_classifier_url:*`` for a rejected URL,
            ``classifier_unavailable: ...`` when the request fails,
            ``classifier_invalid_json`` when the body is not decodable JSON,
            ``classifier_missing_labels`` when the body is not a JSON object
            with a mapping under ``"labels"``.
    """
    classifier_url = validate_classifier_url(classifier_url)
    before = read_npu_busy_time_us()
    payload = {
        "id": f"context-gate-{int(time.time())}",
        "text": query,
        "context": {"platform": (context or {}).get("platform", "cli"), "source": "context_gate"},
        "options": {"include_evidence": False, "include_embedding_debug": False, "dry_run": True},
    }
    req = urllib.request.Request(
        classifier_url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    try:
        with _CLASSIFIER_OPENER.open(req, timeout=timeout) as resp:  # noqa: S310 - local configured endpoint only
            # Bounded read: an oversized body is truncated and then fails parsing below.
            raw = resp.read(256_000)
    except (urllib.error.URLError, TimeoutError, OSError) as exc:
        raise ContextGateError(f"classifier_unavailable: {exc}") from exc
    after = read_npu_busy_time_us()
    try:
        data = json.loads(raw.decode("utf-8"))
    except (UnicodeDecodeError, json.JSONDecodeError) as exc:
        # UnicodeDecodeError is not a JSONDecodeError; it occurs for non-UTF-8
        # bodies and when the bounded read cuts a multi-byte character.
        raise ContextGateError("classifier_invalid_json") from exc
    # Valid JSON may still be an array or scalar, which has no .get().
    labels = data.get("labels") if isinstance(data, Mapping) else None
    if not isinstance(labels, Mapping):
        raise ContextGateError("classifier_missing_labels")
    outer = after - before if before is not None and after is not None else None
    return ClassifierResult(
        labels=labels,
        npu_busy_delta_us=_as_int_or_none(data.get("npu_busy_delta_us")),
        sysfs_npu_busy_delta_us=_as_int_or_none(data.get("sysfs_npu_busy_delta_us")),
        outer_sysfs_delta_us=outer,
        live=True,
    )
def validate_classifier_url(classifier_url: str) -> str:
    """Validate the local-only classifier endpoint before any POST is attempted.

    Only ``http``/``https`` URLs are accepted, and the host must be either the
    literal name ``localhost`` or an IP literal that is a loopback address.
    Any other host name is rejected without being resolved.

    Args:
        classifier_url: Candidate endpoint URL.

    Returns:
        *classifier_url* unchanged when it is acceptable.

    Raises:
        ContextGateError: With an ``invalid_classifier_url:<reason>`` code.
    """
    try:
        parsed = urllib.parse.urlparse(classifier_url)
    except ValueError as exc:
        # urlparse raises a bare ValueError for malformed netlocs (for example
        # unbalanced IPv6 brackets); report it with this module's error type.
        raise ContextGateError("invalid_classifier_url:malformed_url") from exc
    if parsed.scheme not in {"http", "https"}:
        raise ContextGateError("invalid_classifier_url:scheme_must_be_http_or_https")
    host = parsed.hostname
    if not host:
        raise ContextGateError("invalid_classifier_url:missing_host")
    # Ignore case and a trailing root dot when comparing the host.
    host_normalized = host.lower().rstrip(".")
    if host_normalized == "localhost":
        return classifier_url
    try:
        address = ipaddress.ip_address(host_normalized)
    except ValueError as exc:
        raise ContextGateError("invalid_classifier_url:host_must_be_loopback") from exc
    if not address.is_loopback:
        raise ContextGateError("invalid_classifier_url:host_must_be_loopback")
    return classifier_url
def _as_int_or_none(value: Any) -> int | None:
    """Convert *value* with ``int()``, returning ``None`` when that fails.

    Args:
        value: Any object, typically a field taken from a decoded JSON body.

    Returns:
        ``int(value)``, or ``None`` if the conversion raises ``TypeError``,
        ``ValueError`` or ``OverflowError``.
    """
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int(float("inf")); json.loads accepts "Infinity" by
        # default, so such a value can arrive from a response body.
        return None
def classify_offline(query: str, context: Mapping[str, Any] | None = None, warning: str | None = None) -> ClassifierResult:
    """Build a ``ClassifierResult`` from ``heuristic_labels`` without any I/O.

    Args:
        query: Text to label.
        context: Optional request context passed through to the heuristic.
        warning: Warning to attach; a falsy value selects the standard
            offline warning.

    Returns:
        A result with ``live=False`` and every NPU delta set to ``None``.
    """
    note = warning if warning else "offline_heuristic_classifier_no_npu_claim"
    offline_labels = heuristic_labels(query, context)
    return ClassifierResult(
        labels=offline_labels,
        npu_busy_delta_us=None,
        sysfs_npu_busy_delta_us=None,
        outer_sysfs_delta_us=None,
        live=False,
        warning=note,
    )
def _has_any(text: str, needles: list[str]) -> bool:
    """Return ``True`` when at least one of *needles* is a substring of *text*."""
    for needle in needles:
        if needle in text:
            return True
    return False
def _source(source: str, action: str, reason: str, priority: int, freshness: str, confidence: float) -> dict[str, Any]:
    """Build one source-plan entry.

    Args:
        source: Source name; must be a member of ``_ALLOWED_SOURCES``.
        action: Action the executing agent should take for this source.
        reason: Short explanation of why the source was selected.
        priority: Sort key used by ``select_sources`` (lower comes first).
        freshness: Freshness tag for the source.
        confidence: Selection confidence, rounded to two decimals.

    Returns:
        The entry dict. ``permission`` and ``missing_behavior`` are ``"none"``
        and ``"skip_retrieval"`` for ``"no_retrieval"``, and the retrieval
        variants for every other source.

    Raises:
        ContextGateError: ``invalid_source:<name>`` for an unknown source.
    """
    # An explicit raise (rather than assert) keeps the check active under -O.
    if source not in _ALLOWED_SOURCES:
        raise ContextGateError(f"invalid_source:{source}")
    retrieves = source != "no_retrieval"
    return {
        "source": source,
        "action": action,
        "reason": reason,
        "priority": priority,
        "freshness": freshness,
        "permission": "tool_required_by_authoritative_agent" if retrieves else "none",
        "missing_behavior": "retrieve_or_mark_missing" if retrieves else "skip_retrieval",
        "confidence": round(confidence, 2),
    }
def select_sources(query: str, labels: Mapping[str, Any], context: Mapping[str, Any], max_sources: int) -> list[dict[str, Any]]:
    """Choose which sources an executing agent should consult for *query*.

    Each source is triggered by classifier labels, by *context*, or by
    substring keyword matches on the lower-cased query. When nothing triggers,
    a single ``no_retrieval`` entry is produced.

    Args:
        query: Free-text request.
        labels: Classifier labels; ``workflow_category``, ``memory_candidate``
            and ``tool_needed`` are read.
        context: Request context; only ``repo_path`` is read.
        max_sources: Maximum number of entries to return.

    Returns:
        Source entries ordered by ascending priority, one per source name,
        truncated to the first *max_sources*.
    """
    lowered = query.lower()
    category = str(_label_value(labels, "workflow_category", "unknown"))
    memory_candidate = str(_label_value(labels, "memory_candidate", "none"))
    tool_needed = bool(_label_value(labels, "tool_needed", False))

    def mentions(keywords: list[str]) -> bool:
        return _has_any(lowered, keywords)

    candidates: list[dict[str, Any]] = []

    live_words = ["current", "now", "health", "port", "process", "systemd", "status", "npu", "listening", "logs", "time", "date"]
    if tool_needed or mentions(live_words):
        candidates.append(
            _source("live_system", "inspect_with_terminal_or_domain_tool", "current service/system state requested", 1, "live_required", 0.9)
        )

    repo_words = ["repo", "code", "file", "test", "pytest", "diff", "implementation", "hermes", "atlas"]
    if context.get("repo_path") or category == "coding" or mentions(repo_words):
        candidates.append(
            _source("repo_files", "inspect_explicit_repo_paths", "repo-specific implementation or config context", 2, "current_filesystem", 0.84)
        )

    session_words = ["where did we leave", "what did we decide", "previous", "earlier", "handoff", "prior", "last time"]
    if mentions(session_words):
        candidates.append(
            _source("session_search", "search_prior_sessions_or_kanban_handoffs", "prior decision or handoff requested", 3, "session-era", 0.82)
        )

    rag_words = ["runbook", "note", "obsidian", "rag", "docs", "knowledge", "plan"]
    if mentions(rag_words):
        candidates.append(
            _source("rag_search", "query_local_index_read_only", "local docs or indexed knowledge likely useful", 4, "cached_index", 0.76)
        )

    memory_words = ["preference", "remember", "profile", "durable fact"]
    if memory_candidate != "none" or mentions(memory_words):
        candidates.append(
            _source("durable_memory", "read_stable_facts_only", "stable preference/environment facts may be relevant", 5, "static", 0.72)
        )

    web_words = ["latest", "news", "version", "release", "public", "web"]
    if mentions(web_words):
        candidates.append(
            _source("web", "search_public_current_sources", "current external public fact requested", 6, "live_external", 0.7)
        )

    if not candidates:
        candidates.append(
            _source("no_retrieval", "answer_directly", "no factual retrieval dependency detected", 1, "none", 0.78)
        )

    # Stable priority order, first entry per source name wins, bounded length.
    by_name: dict[str, dict[str, Any]] = {}
    for entry in sorted(candidates, key=lambda candidate: candidate["priority"]):
        by_name.setdefault(entry["source"], entry)
    return list(by_name.values())[:max_sources]
def select_bundle_name(query: str, labels: Mapping[str, Any], context: Mapping[str, Any]) -> str:
    """Map a request to one of the five bundle names.

    Rules are tried in order and the first match wins: coding task, ops/debug,
    personal assistant, simple response, and finally research as the default.

    Args:
        query: Free-text request.
        labels: Classifier labels; ``workflow_category`` is read here.
        context: Request context; ``platform`` and ``task_id`` are read here.

    Returns:
        ``"CodingTaskBundle"``, ``"OpsDebugBundle"``,
        ``"PersonalAssistantBundle"``, ``"SimpleResponseBundle"`` or
        ``"ResearchBundle"``.
    """
    lowered = query.lower()
    category = str(_label_value(labels, "workflow_category", "unknown"))

    is_coding_task = context.get("platform") == "kanban" or context.get("task_id") or category == "coding"
    if is_coding_task:
        return "CodingTaskBundle"

    ops_words = ["health", "port", "systemd", "npu", "service", "logs"]
    if category in {"devops", "debugging"} or _has_any(lowered, ops_words):
        return "OpsDebugBundle"

    personal_words = ["preference", "remember", "profile"]
    if category in {"note_taking", "productivity"} or _has_any(lowered, personal_words):
        return "PersonalAssistantBundle"

    # Only the top-priority source is needed to see whether retrieval was skipped.
    top_sources = select_sources(query, labels, context, 1)
    if any(entry["source"] == "no_retrieval" for entry in top_sources):
        return "SimpleResponseBundle"
    return "ResearchBundle"
def _field(field: str, shape: str, source: str, freshness: str, missing: str, privacy: str, confidence: float = 0.8) -> dict[str, Any]:
    """Build one bundle field descriptor.

    Args:
        field: Field name.
        shape: Expected shape tag of the field's content.
        source: Stored under ``source_of_truth``.
        freshness: Freshness tag.
        missing: Stored under ``missing_behavior``.
        privacy: Privacy tag.
        confidence: Rounded to two decimals in the result.

    Returns:
        The descriptor dict; ``provenance_required`` is always ``True``.
    """
    descriptor: dict[str, Any] = dict(
        field=field,
        shape=shape,
        source_of_truth=source,
        freshness=freshness,
        provenance_required=True,
        missing_behavior=missing,
        privacy=privacy,
    )
    descriptor["confidence"] = round(confidence, 2)
    return descriptor
def build_bundle_plan(bundle_name: str, sources: Sequence[Mapping[str, Any]], query: str, labels: Mapping[str, Any]) -> dict[str, Any]:
    """Assemble the field plan for *bundle_name*.

    Args:
        bundle_name: Bundle to describe; a name without its own table gets the
            research field set.
        sources: Selected source entries; only their ``source`` names are read.
        query: Free-text request, scanned for side-effect verbs.
        labels: Classifier labels; ``safety_confirmation_required`` is read.

    Returns:
        A dict with ``bundle_name``, ``required_fields``, ``optional_fields``
        (always empty) and ``blocked_fields``.
    """
    source_names = {entry["source"] for entry in sources}

    # Each row holds the positional arguments for _field().
    ops_fields = (
        ("problem_statement", "compact_text", "user", "request", "mark_missing", "query_text_only"),
        ("target_scope", "service_repo_or_host", "query_or_classifier", "request", "ask_or_infer_low_confidence", "no_private_paths_beyond_explicit"),
        ("live_state", "status_table", "live_system", "live_required", "retrieve_or_fail_closed", "no_raw_logs_by_default"),
        ("safety_gates", "closed_gate_map", "policy", "static", "fail_closed", "no_private_data"),
        ("provenance", "tool_names_and_paths", "executing_agent", "run", "mark_missing", "paths_only"),
    )
    coding_fields = (
        ("repo_root", "absolute_path", "task_or_context", "current", "ask_or_fail", "explicit_path_only"),
        ("git_state", "branch_dirty_counts", "live_system", "live_required", "retrieve_or_fail_closed", "no_diff_dump_by_default"),
        ("requirements", "bullet_summary", "user_kanban_files", "current", "retrieve_or_mark_missing", "no_private_snippets"),
        ("relevant_paths", "path_list", "repo_files", "current_filesystem", "search_narrowly", "paths_only"),
        ("tests_or_smokes", "command_list", "repo_files", "current_filesystem", "mark_missing", "commands_only"),
        ("review_gates", "closed_gate_map", "policy", "static", "fail_closed", "no_private_data"),
    )
    personal_fields = (
        ("user_intent", "compact_text", "user", "request", "mark_missing", "query_text_only"),
        ("durable_facts_needed", "fact_keys", "durable_memory", "static", "retrieve_or_mark_missing", "no_raw_memory_dump"),
        ("prior_decisions_needed", "session_refs", "session_search", "session-era", "retrieve_or_mark_missing", "summaries_only"),
        ("privacy_boundary", "closed_gate_map", "policy", "static", "fail_closed", "no_private_data"),
        ("action_authority", "closed_gate_map", "policy", "static", "fail_closed", "no_private_data"),
    )
    research_fields = (
        ("research_question", "compact_text", "user", "request", "mark_missing", "query_text_only"),
        ("source_plan", "ordered_source_list", "context_gate", "run", "mark_missing", "no_private_snippets"),
        ("evidence_requirements", "provenance_rules", "policy", "static", "fail_closed", "no_private_data"),
        ("freshness_cutoff", "freshness_policy", "classifier_query", "request", "mark_missing", "no_private_data"),
        ("missing_data_behavior", "policy_enum", "policy", "static", "fail_closed", "no_private_data"),
    )
    tables = {
        "OpsDebugBundle": ops_fields,
        "CodingTaskBundle": coding_fields,
        "PersonalAssistantBundle": personal_fields,
        "SimpleResponseBundle": (),
    }
    required = [_field(*row) for row in tables.get(bundle_name, research_fields)]

    # A side-effect entry is blocked when the classifier asks for confirmation
    # or the query names a side-effect verb as a whole word.
    needs_approval = bool(_label_value(labels, "safety_confirmation_required", False))
    if not needs_approval:
        verb_hit = re.search(r"\b(route|routing|restart|send|write memory|reindex|delete|mutate)\b", query.lower())
        needs_approval = verb_hit is not None

    blocked = []
    if needs_approval:
        blocked.append(_field("authority_side_effect", "approval_required", "policy", "static", "fail_closed", "no_side_effects_in_v1", 0.95))
    if "rag_search" in source_names:
        blocked.append(_field("vector_db_mutation", "not_allowed", "policy", "static", "fail_closed", "read_only_query_plan", 0.95))

    return {
        "bundle_name": bundle_name,
        "required_fields": required,
        "optional_fields": [],
        "blocked_fields": blocked,
    }
def summarize_query_class(labels: Mapping[str, Any]) -> dict[str, Any]:
    """Flatten classifier labels into the plan's ``query_class`` block.

    Args:
        labels: Classifier labels.

    Returns:
        A dict with the unwrapped label values plus ``confidence``, which is
        the highest confidence among ``workflow_category``, ``tool_needed``
        and ``safety_confirmation_required``, rounded to two decimals.
    """
    summary: dict[str, Any] = {
        "workflow_category": _label_value(labels, "workflow_category", "unknown"),
        "urgency": _label_value(labels, "urgency", "normal"),
        "tool_needed": bool(_label_value(labels, "tool_needed", False)),
        "memory_candidate": _label_value(labels, "memory_candidate", "none"),
        "safety_confirmation_required": bool(_label_value(labels, "safety_confirmation_required", False)),
    }
    scored_labels = ("workflow_category", "tool_needed", "safety_confirmation_required")
    best = max(_label_confidence(labels, label_name, 0.5) for label_name in scored_labels)
    summary["confidence"] = round(best, 2)
    return summary
def npu_proof_from_classifier(result: ClassifierResult, require_npu_proof: bool) -> tuple[dict[str, Any], list[str]]:
    """Turn a classifier result into the plan's ``npu_proof`` block and warnings.

    The proof counts as verified only for a live result whose sysfs delta
    (as reported by the classifier or as measured around the call) is
    positive. The classifier's ``npu_busy_delta_us`` is reported but does not
    contribute to verification.

    Args:
        result: Classifier outcome to summarize.
        require_npu_proof: When true, an unverified proof adds the
            ``npu_proof_inconclusive`` warning.

    Returns:
        ``(proof, warnings)``; *warnings* starts with ``result.warning`` when
        that is set.
    """

    def is_positive(delta: int | None) -> bool:
        return delta is not None and delta > 0

    reported_sysfs = result.sysfs_npu_busy_delta_us
    measured_outer = result.outer_sysfs_delta_us
    sysfs_moved = is_positive(reported_sysfs)
    outer_moved = is_positive(measured_outer)
    verified = bool(result.live and (sysfs_moved or outer_moved))

    warnings: list[str] = [result.warning] if result.warning else []
    if require_npu_proof and not verified:
        warnings.append("npu_proof_inconclusive")

    proof = {
        "classifier_delta_us": result.npu_busy_delta_us,
        "classifier_sysfs_delta_us": reported_sysfs,
        "outer_sysfs_delta_us": measured_outer,
        "rerank_delta_us": None,
        "verified": verified,
        "required": require_npu_proof,
        "classifier_live": result.live,
    }
    return proof, warnings
def build_plan(
    query: str,
    *,
    context: Mapping[str, Any] | None = None,
    options: Mapping[str, Any] | None = None,
    classifier: ClassifierResult | None = None,
) -> dict[str, Any]:
    """Build and validate the advisory context-gate plan for *query*.

    Args:
        query: Non-blank request text.
        context: Optional request context (copied, never mutated).
        options: Optional settings (copied). Keys read: ``dry_run`` (must be
            ``True``), ``include_private_text`` (must be falsy),
            ``max_sources`` (default 4, clamped to 1..6),
            ``require_npu_proof`` (default ``True``) and ``trace_id``.
        classifier: Pre-computed classifier result; when omitted the offline
            heuristic classifier is used.

    Returns:
        The plan dict, already checked by ``validate_plan``.

    Raises:
        ContextGateError: ``query_required``, ``dry_run_must_remain_true_in_v1``,
            ``include_private_text_not_allowed_in_v1``, ``invalid_max_sources``,
            or any error raised by ``validate_plan``.
    """
    # A non-string query is an invalid request, not a crash in .strip().
    if not isinstance(query, str) or not query.strip():
        raise ContextGateError("query_required")
    context = dict(context or {})
    options = dict(options or {})
    if options.get("dry_run", True) is not True:
        raise ContextGateError("dry_run_must_remain_true_in_v1")
    if options.get("include_private_text", False):
        raise ContextGateError("include_private_text_not_allowed_in_v1")
    try:
        requested_sources = int(options.get("max_sources", 4))
    except (TypeError, ValueError, OverflowError) as exc:
        # e.g. max_sources=None or "many": report as an invalid request.
        raise ContextGateError("invalid_max_sources") from exc
    max_sources = max(1, min(6, requested_sources))
    require_npu = bool(options.get("require_npu_proof", True))
    if classifier is None:
        classifier = classify_offline(query, context)
    labels = classifier.labels
    source_plan = select_sources(query, labels, context, max_sources)
    bundle_name = select_bundle_name(query, labels, context)
    npu_proof, warnings = npu_proof_from_classifier(classifier, require_npu)
    plan = {
        "schema": SCHEMA,
        "trace_id": options.get("trace_id") or context.get("trace_id"),
        "dry_run": True,
        "ok": True,
        "query_class": summarize_query_class(labels),
        "source_plan": source_plan,
        "bundle_plan": build_bundle_plan(bundle_name, source_plan, query, labels),
        "npu_proof": npu_proof,
        # Copies, so a caller mutating the plan cannot alter the module constants.
        "authority": dict(AUTHORITY),
        "gates": dict(GATES),
        "warnings": warnings,
    }
    validate_plan(plan)
    return plan
def validate_plan(plan: Mapping[str, Any]) -> None:
    """Check that *plan* has the expected schema and is fully closed.

    Args:
        plan: Plan mapping to check.

    Raises:
        ContextGateError: ``invalid_schema``, ``dry_run_missing``,
            ``authority_not_closed``, ``source_plan_required``,
            ``invalid_source:<name>``, ``missing_block:<name>`` or
            ``gates_not_closed``.
    """
    if plan.get("schema") != SCHEMA:
        raise ContextGateError("invalid_schema")
    if plan.get("dry_run") is not True:
        raise ContextGateError("dry_run_missing")
    if plan.get("authority") != AUTHORITY:
        raise ContextGateError("authority_not_closed")
    sources = plan.get("source_plan")
    if not isinstance(sources, list) or not sources:
        raise ContextGateError("source_plan_required")
    for item in sources:
        # A non-mapping entry or a non-string name is invalid, not a crash.
        source_name = item.get("source") if isinstance(item, Mapping) else None
        if not isinstance(source_name, str) or source_name not in _ALLOWED_SOURCES:
            raise ContextGateError(f"invalid_source:{source_name}")
    required_blocks = ["query_class", "bundle_plan", "npu_proof", "gates"]
    for block in required_blocks:
        if block not in plan:
            raise ContextGateError(f"missing_block:{block}")
    # Presence alone would accept an opened gate; require the closed map,
    # mirroring the exact-match check on authority above.
    if plan.get("gates") != GATES:
        raise ContextGateError("gates_not_closed")
def compact_line(plan: Mapping[str, Any]) -> str:
    """Render *plan* as a single line of space-separated ``key=value`` pairs.

    The ``gates=closed:...`` segment is a fixed string; it is not derived from
    the plan's ``gates`` block. Missing or empty warnings render as ``none``.

    Args:
        plan: Plan mapping with ``ok``, ``schema``, ``bundle_plan``,
            ``source_plan`` and ``npu_proof``; ``warnings`` is optional.

    Returns:
        The one-line summary.
    """
    source_entries = plan["source_plan"]
    proof = plan["npu_proof"]
    source_names = ",".join(str(entry["source"]) for entry in source_entries)
    warning_text = ",".join(plan.get("warnings") or []) or "none"
    segments = [
        f"ok={str(plan['ok']).lower()}",
        f"schema={plan['schema']}",
        f"bundle={plan['bundle_plan']['bundle_name']}",
        f"sources={source_names}",
        f"source_count={len(source_entries)}",
        f"npu_verified={str(proof['verified']).lower()}",
        f"classifier_delta_us={proof.get('classifier_delta_us')}",
        f"outer_sysfs_delta_us={proof.get('outer_sysfs_delta_us')}",
        "gates=closed:route,memory,send,tools,restart,vector,private_roots,config",
        f"warnings={warning_text}",
    ]
    return " ".join(segments)
def compact_json(plan: Mapping[str, Any]) -> str:
    """Serialize a reduced view of *plan* as compact JSON with sorted keys.

    Args:
        plan: Plan mapping with ``schema``, ``ok``, ``dry_run``,
            ``bundle_plan``, ``source_plan``, ``query_class``, ``npu_proof``,
            ``authority`` and ``gates``; ``warnings`` is optional.

    Returns:
        A JSON object string. ``gates_closed`` lists the key names of the
        plan's ``gates`` block, and ``warnings`` is always a JSON array.
    """
    source_entries = plan["source_plan"]
    compact = {
        "schema": plan["schema"],
        "ok": plan["ok"],
        "dry_run": plan["dry_run"],
        "bundle_name": plan["bundle_plan"]["bundle_name"],
        "sources": [entry["source"] for entry in source_entries],
        "source_count": len(source_entries),
        "query_class": plan["query_class"],
        "npu_proof": plan["npu_proof"],
        "authority": plan["authority"],
        "gates_closed": list(plan["gates"]),
        # "or []" matches compact_line: a present-but-None value is treated as
        # no warnings instead of being emitted as null.
        "warnings": plan.get("warnings") or [],
    }
    return json.dumps(compact, sort_keys=True, separators=(",", ":"))