Files
swarm-zap/scripts/sync-litellm-models.py

369 lines
14 KiB
Python
Executable File

#!/usr/bin/env python3
import argparse
import json
import os
import shutil
import sys
import urllib.request
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
# Per-user OpenClaw directory; the config file and the workspace both live here.
_OPENCLAW_HOME = Path.home() / ".openclaw"

# Default config file that main() reads and rewrites.
CONFIG_PATH = _OPENCLAW_HOME / "openclaw.json"
WORKSPACE = _OPENCLAW_HOME / "workspace"
# Default location of the curated model-metadata file.
METADATA_PATH = WORKSPACE.joinpath("models", "litellm-official-metadata.json")

# Timeout (seconds) passed to urlopen() for every HTTP request.
TIMEOUT = 12

# Values used when no source provides contextWindow / maxTokens for a model.
FALLBACK_CONTEXT = 200_000
FALLBACK_MAX_TOKENS = 8_192
def die(msg: str, code: int = 1):
    """Report a fatal error on stderr and terminate.

    Args:
        msg: Human-readable description; emitted as ``ERROR: <msg>``.
        code: Process exit status carried by the raised ``SystemExit``.

    Raises:
        SystemExit: Always.
    """
    sys.stderr.write(f"ERROR: {msg}\n")
    sys.exit(code)
def normalize_base(url: str) -> str:
    """Return *url* without trailing slashes; a falsy value becomes ``""``."""
    if not url:
        return ""
    return url.rstrip("/")
def load_json(path: Path) -> dict[str, Any]:
    """Parse the UTF-8 JSON file at *path* and return the decoded value.

    A missing file or malformed JSON is fatal: both are reported through
    ``die()``, which terminates the process. Other I/O errors propagate.
    """
    try:
        with path.open(encoding="utf-8") as handle:
            return json.load(handle)
    except FileNotFoundError:
        die(f"File not found: {path}")
    except json.JSONDecodeError as e:
        die(f"Invalid JSON in {path}: {e}")
def resolve_json_pointer(doc: Any, pointer: str) -> Any:
    """Look up the value addressed by a JSON Pointer inside *doc*.

    Follows RFC 6901 escaping (``~1`` -> ``/``, ``~0`` -> ``~``) and can step
    through both objects (by key) and arrays (by zero-based decimal index).
    For backward compatibility the function is lenient about leading
    slashes: ``""`` and ``"/"`` both address the whole document, and any run
    of leading ``/`` characters is ignored.

    Args:
        doc: Parsed JSON value to search.
        pointer: Pointer string such as ``"/providers/litellm/apiKey"``.

    Returns:
        The value found at *pointer*.

    Raises:
        KeyError: If any segment cannot be resolved. The full pointer is the
            exception argument.
    """
    if pointer in ("", "/"):
        return doc
    cur = doc
    for raw_part in pointer.lstrip("/").split("/"):
        # Order matters: decode "~1" first so "~01" becomes "~1", not "/".
        part = raw_part.replace("~1", "/").replace("~0", "~")
        if isinstance(cur, dict):
            if part not in cur:
                raise KeyError(pointer)
            cur = cur[part]
        elif isinstance(cur, list):
            # RFC 6901 array index: ASCII digits, no leading zeros.
            is_index = (
                part.isascii()
                and part.isdigit()
                and (part == "0" or not part.startswith("0"))
            )
            if not is_index or int(part) >= len(cur):
                raise KeyError(pointer)
            cur = cur[int(part)]
        else:
            raise KeyError(pointer)
    return cur
def resolve_api_key(raw_api_key: Any, cfg: dict[str, Any]) -> str | None:
if isinstance(raw_api_key, str) and raw_api_key.strip():
return raw_api_key.strip()
if isinstance(raw_api_key, dict):
source = raw_api_key.get("source")
if source == "env":
name = raw_api_key.get("name") or raw_api_key.get("id")
if isinstance(name, str) and name:
return os.environ.get(name)
if source == "file":
provider_id = raw_api_key.get("provider") or ((cfg.get("secrets") or {}).get("defaults") or {}).get("file")
providers = (((cfg.get("secrets") or {}).get("providers") or {}))
provider = providers.get(provider_id) if isinstance(provider_id, str) else None
if isinstance(provider, dict) and provider.get("source") == "file":
path = provider.get("path")
pointer = raw_api_key.get("id")
if isinstance(path, str) and isinstance(pointer, str):
try:
secret_doc = load_json(Path(path))
value = resolve_json_pointer(secret_doc, pointer)
if isinstance(value, str) and value:
return value
except Exception:
pass
return (
os.environ.get("LITELLM_API_KEY")
or os.environ.get("OPENAI_API_KEY")
or None
)
def fetch_json(url: str, api_key: str | None):
    """GET *url* and return the decoded JSON body.

    Sends ``Accept: application/json`` and, when *api_key* is truthy, an
    ``Authorization: Bearer`` header. The body is decoded as UTF-8 with
    undecodable bytes replaced. Network and parse errors propagate.
    """
    headers = {"Accept": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    request = urllib.request.Request(url, headers=headers, method="GET")
    with urllib.request.urlopen(request, timeout=TIMEOUT) as response:
        body = response.read()
    return json.loads(body.decode("utf-8", errors="replace"))
def fetch_model_detail(base_root: str, model_id: str, api_key: str | None):
    """Fetch ``{base_root}/models/{model_id}``; return ``None`` on any failure.

    The model id is fully percent-encoded (including ``/``) so that ids with
    slashes stay within a single path segment.
    """
    from urllib.parse import quote as url_quote

    try:
        encoded_id = url_quote(model_id, safe="")
        detail_url = f"{base_root}/models/{encoded_id}"
        return fetch_json(detail_url, api_key)
    except Exception:
        # Best effort: per-model detail is optional, the caller counts misses.
        return None
def fetch_models_and_info(
    base_url: str, api_key: str | None
) -> tuple[list[str], dict[str, dict[str, Any]], dict[str, Any], str | None]:
    """List the models served at *base_url* and fetch per-model detail.

    The model list is read from ``<base>/v1/models`` (``/v1`` is appended
    unless *base_url* already ends with it). Detail for each model is then
    requested from ``<base>/models/<id>`` (no ``/v1`` prefix); detail
    failures are tolerated and only summarised in the returned warning.

    Args:
        base_url: Base URL of the LiteLLM endpoint, with or without ``/v1``.
        api_key: Bearer token, or ``None`` for unauthenticated requests.

    Returns:
        ``(model_ids, model_rows, model_info, info_error)``:
        ids in first-seen order without duplicates, the first list row seen
        for each id, the detail payload for each id that returned a JSON
        object, and a warning string (or ``None``) when some detail lookups
        failed.

    Raises:
        SystemExit: Via ``die()`` when the base URL is empty, the model list
            cannot be fetched or parsed, or its shape is unexpected.
    """
    root = normalize_base(base_url)
    if not root:
        die("litellm.baseUrl is empty")
    api_base = root if root.endswith("/v1") else f"{root}/v1"
    models_url = f"{api_base}/models"

    try:
        payload = fetch_json(models_url, api_key)
    except (OSError, ValueError) as exc:
        # OSError covers URLError/HTTPError/timeouts; ValueError covers bad
        # JSON. Report like every other fatal condition, not a traceback.
        die(f"Failed to fetch model list from {models_url}: {exc}")

    if isinstance(payload, dict) and isinstance(payload.get("data"), list):
        rows = payload["data"]
    elif isinstance(payload, list):
        rows = payload
    else:
        die(f"Unexpected /models payload shape: {type(payload).__name__}")

    model_ids: list[str] = []
    model_rows: dict[str, dict[str, Any]] = {}
    for row in rows:
        if not isinstance(row, dict):
            continue
        mid = row.get("id") or row.get("model")
        if not isinstance(mid, str):
            continue
        mid = mid.strip()
        # Keep only the first row for each id so the order stays stable.
        if mid and mid not in model_rows:
            model_ids.append(mid)
            model_rows[mid] = row

    detail_root = api_base.removesuffix("/v1")
    model_info: dict[str, Any] = {}
    for mid in model_ids:
        detail = fetch_model_detail(detail_root, mid, api_key)
        if isinstance(detail, dict):
            model_info[mid] = detail

    info_error = None
    detail_errors = len(model_ids) - len(model_info)
    if detail_errors:
        info_error = f"model detail unavailable for {detail_errors}/{len(model_ids)} models"
    return model_ids, model_rows, model_info, info_error
def load_metadata(path: Path) -> dict[str, Any]:
    """Load the curated metadata file, or an empty catalogue if it is absent.

    Returns ``{"models": {}}`` when *path* does not exist. Otherwise the file
    is parsed with ``load_json()`` and must hold an object under the
    ``"models"`` key; any other shape is fatal via ``die()``.
    """
    if not path.exists():
        return {"models": {}}
    data = load_json(path)
    if isinstance(data.get("models"), dict):
        return data
    die(f"Metadata file {path} must contain an object at key 'models'")
def pick_model_info(model_id: str, model_rows: dict[str, dict[str, Any]], model_info: Any):
    """Return ``(row, info)`` for *model_id*.

    ``row`` is the model's entry in *model_rows* (``{}`` if missing or
    falsy). ``info`` is looked up in *model_info* in two ways, first match
    wins: an item of ``model_info["data"]`` whose ``model_name`` equals the
    id, then ``model_info[model_id]`` if that is a dict. ``info`` is
    ``None`` when neither lookup succeeds.
    """
    row = model_rows.get(model_id) or {}
    if not isinstance(model_info, dict):
        return row, None

    info = None
    entries = model_info.get("data")
    if isinstance(entries, list):
        info = next(
            (
                entry
                for entry in entries
                if isinstance(entry, dict) and entry.get("model_name") == model_id
            ),
            None,
        )
    if info is None:
        keyed = model_info.get(model_id)
        if isinstance(keyed, dict):
            info = keyed
    return row, info
def clean_input(value: Any) -> list[str] | None:
if isinstance(value, list):
out = [x for x in value if isinstance(x, str) and x]
return out or None
return None
def metadata_from_litellm(model_id: str, model_rows: dict[str, dict[str, Any]], model_info: Any) -> dict[str, Any]:
    """Extract config-style metadata for *model_id* from LiteLLM responses.

    The list row is consulted before the detail payload, and within each
    source the keys are tried in the order listed below. The first positive
    integer found for a target field wins. The result carries
    ``"source": "litellm-api"`` whenever at least one field was found, and
    is ``{}`` otherwise.
    """
    # (key in the LiteLLM payload, key in the OpenClaw model entry)
    token_fields = (
        ("context_window", "contextWindow"),
        ("max_input_tokens", "contextWindow"),
        ("max_output_tokens", "maxTokens"),
        ("max_tokens", "maxTokens"),
    )
    row, info = pick_model_info(model_id, model_rows, model_info)
    out: dict[str, Any] = {}
    for source in (row, info):
        if not isinstance(source, dict):
            continue
        for src_key, dst_key in token_fields:
            if dst_key in out:
                continue
            val = source.get(src_key)
            if isinstance(val, int) and val > 0:
                out[dst_key] = val
        if "input" not in out:
            inp = clean_input(source.get("input_types") or source.get("input"))
            if inp:
                out["input"] = inp
        if "reasoning" not in out:
            reasoning = source.get("supports_reasoning")
            if isinstance(reasoning, bool):
                out["reasoning"] = reasoning
    if out:
        out["source"] = "litellm-api"
    return out
def official_alias_metadata(model_id: str, official_models: dict[str, Any]) -> dict[str, Any]:
    """Find curated metadata for *model_id*, directly or through an alias.

    Lookup order:

    1. an entry keyed by *model_id* itself (returned as-is);
    2. for ``copilot-<base>`` ids, the entry for ``<base>``;
    3. the entry for the id's target in the built-in alias table.

    Alias-derived results are copies whose ``source`` is rewritten to
    ``alias:<original source>`` (``alias:official`` if the base entry has no
    source). Returns ``{}`` when nothing matches.
    """
    alias_map = {
        'gemini-flash-latest': 'gemini-2.5-flash',
        'gemini-flash-lite-latest': 'gemini-2.5-flash-lite',
        'gemini-pro-latest': 'gemini-2.5-pro',
        'gemini-3-flash-preview': 'gemini-2.5-flash',
        'gemini-3-pro-preview': 'gemini-2.5-pro',
        'gemini-3.1-pro-preview': 'gemini-2.5-pro',
        'gpt-5.1-codex-max': 'gpt-5.1-codex',
    }

    exact = official_models.get(model_id)
    if isinstance(exact, dict):
        return exact

    # Candidate base ids, in priority order.
    candidates = []
    if model_id.startswith('copilot-'):
        candidates.append(model_id[len('copilot-'):])
    mapped = alias_map.get(model_id)
    if mapped:
        candidates.append(mapped)

    for base in candidates:
        base_meta = official_models.get(base)
        if isinstance(base_meta, dict):
            derived = dict(base_meta)
            derived['source'] = f"alias:{base_meta.get('source', 'official')}"
            return derived
    return {}
def merge_metadata(existing: dict[str, Any], official: dict[str, Any], litellm_meta: dict[str, Any], model_id: str) -> tuple[dict[str, Any], str]:
    """Build the config entry for one model and report where its data came from.

    For each of ``contextWindow``, ``maxTokens``, ``input`` and
    ``reasoning`` the value is chosen in this order:

    1. *official* metadata (overrides whatever the existing entry holds);
    2. the value already present in *existing*;
    3. *litellm_meta* (API-reported);
    4. a built-in default (``FALLBACK_CONTEXT``, ``FALLBACK_MAX_TOKENS``,
       ``["text"]``, ``False``).

    Args:
        existing: Current config entry for the model (may be empty). It is
            not mutated.
        official: Curated metadata for the model, possibly alias-derived.
        litellm_meta: Metadata extracted from the LiteLLM API.
        model_id: Id written into the entry.

    Returns:
        ``(entry, source)``. ``source`` summarises provenance with a fixed
        precedence, independent of field order: ``"fallback-default"`` if
        ``contextWindow`` or ``maxTokens`` had to be defaulted, else the
        official source label if any field came from *official*, else the
        LiteLLM label if any field came from *litellm_meta*, else
        ``"existing-config"``.
    """
    merged = dict(existing)
    merged.pop("metadataSource", None)
    merged["id"] = model_id
    # Every entry needs a string name; fall back to the id otherwise.
    if not isinstance(merged.get("name"), str):
        merged["name"] = model_id

    used_official = used_litellm = used_fallback = False
    for field in ("contextWindow", "maxTokens", "input", "reasoning"):
        if field in official and official[field] not in (None, [], ""):
            merged[field] = official[field]
            used_official = True
        elif field in merged:
            continue
        elif field in litellm_meta and litellm_meta[field] not in (None, [], ""):
            merged[field] = litellm_meta[field]
            used_litellm = True
        elif field == "contextWindow":
            merged[field] = FALLBACK_CONTEXT
            used_fallback = True
        elif field == "maxTokens":
            merged[field] = FALLBACK_MAX_TOKENS
            used_fallback = True
        elif field == "input":
            # Defaulting these two is not treated as "missing metadata".
            merged[field] = ["text"]
        else:
            merged[field] = False

    if used_fallback:
        source_used = "fallback-default"
    elif used_official:
        source_used = official.get("source", "official-metadata")
    elif used_litellm:
        source_used = litellm_meta.get("source", "litellm-api")
    else:
        source_used = "existing-config"
    return merged, source_used
def build_sync_report(models: list[dict[str, Any]], official_meta: dict[str, Any], source_map: dict[str, str]):
    """Summarise a sync run by metadata provenance.

    Args:
        models: The merged model entries; only their count is used.
        official_meta: Accepted for interface compatibility; not read.
        source_map: Model id -> provenance label from ``merge_metadata``.

    Returns:
        A dict with the total, the number of models whose label starts with
        ``official-`` or ``alias:``, and the ids that fell back to defaults
        (listed under both ``fallbackModels`` and
        ``missingOfficialMetadata``).
    """
    official_ids: list[str] = []
    alias_ids: list[str] = []
    fallback_ids: list[str] = []
    for mid, src in source_map.items():
        if src == "fallback-default":
            fallback_ids.append(mid)
        elif src.startswith("official-"):
            official_ids.append(mid)
        elif src.startswith("alias:"):
            alias_ids.append(mid)

    return {
        "total": len(models),
        "officialCount": len(official_ids),
        "aliasDerivedCount": len(alias_ids),
        "fallbackCount": len(fallback_ids),
        "fallbackModels": fallback_ids,
        # Separate list object so callers can mutate one without the other.
        "missingOfficialMetadata": fallback_ids.copy(),
    }
def _ensure_object(parent: dict[str, Any], key: str) -> dict[str, Any]:
    """Return ``parent[key]`` as a dict, creating it in place when absent."""
    child = parent.get(key)
    if child is None:
        child = parent[key] = {}
    elif not isinstance(child, dict):
        die(f"{key} must be an object")
    return child


def _print_capped(title: str, items: list[str], limit: int = 30) -> None:
    """Print *title* and up to *limit* items, then a '+N more' line if cut."""
    print(title)
    for item in items[:limit]:
        print(f" - {item}")
    if len(items) > limit:
        print(f" ... +{len(items) - limit} more")


def main():
    """Sync the LiteLLM model list and metadata into the OpenClaw config.

    Steps: load the config, query the LiteLLM endpoint for models, merge
    each model with curated and API metadata, update
    ``models.providers.litellm.models`` and the ``litellm/*`` keys of
    ``agents.defaults.models``, write the config (after a timestamped
    backup), and print a report.

    Flags:
        --dry-run: compute everything but do not write the config.
        --audit-only: only report; the config is neither modified nor
            written.
        --json: print the report as JSON instead of text.
    """
    parser = argparse.ArgumentParser(description="Sync LiteLLM model ids and metadata into OpenClaw config")
    parser.add_argument("--config", type=Path, default=CONFIG_PATH)
    parser.add_argument("--metadata", type=Path, default=METADATA_PATH)
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--audit-only", action="store_true")
    parser.add_argument("--json", action="store_true", help="Print report as JSON")
    args = parser.parse_args()

    cfg = load_json(args.config)
    providers = (cfg.get("models") or {}).get("providers") or {}
    litellm = providers.get("litellm")
    if not isinstance(litellm, dict):
        die("models.providers.litellm not found")
    base_url = litellm.get("baseUrl")
    api_key = resolve_api_key(litellm.get("apiKey"), cfg)

    model_ids, model_rows, model_info, info_error = fetch_models_and_info(base_url, api_key)
    if not model_ids:
        die("No models returned from LiteLLM /v1/models")

    meta_file = load_metadata(args.metadata)
    official_models = meta_file.get("models", {})
    if not isinstance(official_models, dict):
        die("metadata models must be an object")

    existing_models = litellm.get("models") if isinstance(litellm.get("models"), list) else []
    existing_by_id = {
        m.get("id"): m
        for m in existing_models
        if isinstance(m, dict) and isinstance(m.get("id"), str)
    }

    new_models = []
    source_map: dict[str, str] = {}
    for mid in model_ids:
        existing = dict(existing_by_id.get(mid, {}))
        official = official_alias_metadata(mid, official_models)
        litellm_meta = metadata_from_litellm(mid, model_rows, model_info)
        merged, source_used = merge_metadata(existing, official, litellm_meta, mid)
        new_models.append(merged)
        source_map[mid] = source_used

    report = build_sync_report(new_models, official_models, source_map)
    if info_error:
        report["modelInfoWarning"] = info_error

    if not args.audit_only:
        litellm["models"] = new_models
        # Attach agents.defaults to cfg if missing; a detached dict would
        # silently drop the model map update.
        defaults = _ensure_object(_ensure_object(cfg, "agents"), "defaults")
        model_map = defaults.get("models") if isinstance(defaults.get("models"), dict) else {}
        # Keep non-LiteLLM entries; rebuild litellm/* from the live list.
        preserved = {k: v for k, v in model_map.items() if not k.startswith("litellm/")}
        for model in new_models:
            key = f"litellm/{model['id']}"
            entry = model_map.get(key)
            preserved[key] = entry if isinstance(entry, dict) else {}
        defaults["models"] = preserved

    # Audit mode never touches the file: no rewrite and no backup.
    if not args.dry_run and not args.audit_only:
        ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
        # Append to the full file name so any config extension is preserved.
        backup = args.config.with_name(f"{args.config.name}.bak-{ts}")
        shutil.copy2(args.config, backup)
        args.config.write_text(json.dumps(cfg, indent=2) + "\n", encoding="utf-8")
        report["backup"] = str(backup)
        report["updatedConfig"] = str(args.config)

    if args.json:
        print(json.dumps(report, indent=2))
        return

    print(f"Synced {report['total']} LiteLLM models")
    print(f"Official metadata: {report['officialCount']}")
    print(f"Fallback metadata: {report['fallbackCount']}")
    if report.get("missingOfficialMetadata"):
        _print_capped("Missing official metadata:", report["missingOfficialMetadata"])
    if report.get("fallbackModels"):
        _print_capped("Still using fallback defaults:", report["fallbackModels"])
    if report.get("modelInfoWarning"):
        print(f"LiteLLM /model/info warning: {report['modelInfoWarning']}")
    if report.get("backup"):
        print(f"Backup: {report['backup']}")


if __name__ == "__main__":
    main()