268 lines
8.5 KiB
Python
268 lines
8.5 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
MCP server pro MongoDB — používá FastMCP.
|
|
Spustit: python mcp_mongo.py
|
|
"""
|
|
|
|
import json
|
|
import sys
|
|
import traceback
|
|
from datetime import datetime, date
|
|
from typing import Optional, Union
|
|
|
|
from bson import ObjectId
|
|
from mcp.server.fastmcp import FastMCP
|
|
from pymongo import MongoClient
|
|
|
|
# Host passed to MongoClient below; also echoed in log messages and in ping() results.
MONGO_HOST = "192.168.1.76"
|
|
|
|
|
|
def log(msg: str):
    """Write *msg* followed by a newline to stderr and flush immediately.

    NOTE(review): stderr is presumably used so that stdout stays free for the
    MCP transport -- confirm against the FastMCP transport in use.
    """
    stream = sys.stderr
    stream.write(f"{msg}\n")
    stream.flush()
|
|
|
|
|
|
# Build the shared module-level client once at import time and fail fast:
# any error while connecting is logged and the process exits with status 1.
try:
    client = MongoClient(MONGO_HOST, serverSelectionTimeoutMS=5000)
    # Issue a real request now so an unreachable server is reported here
    # instead of on the first tool call.
    client.server_info()
    log(f"Connected to MongoDB ({MONGO_HOST})")
except Exception as exc:
    log(f"MongoDB connection failed: {exc}")
    sys.exit(1)
|
|
|
|
|
|
def serialize(obj):
    """Recursively convert a MongoDB value into JSON-friendly Python data.

    Conversions applied:
      * ``dict`` / ``list`` -> rebuilt with every value / item converted
        (dict keys are left untouched)
      * ``ObjectId`` -> ``str(obj)``
      * ``datetime`` / ``date`` -> ``obj.isoformat()``
      * ``bytes`` -> text decoded as UTF-8, undecodable bytes replaced

    Any other value is returned unchanged.
    """
    # Containers first: recurse into them and rebuild plain dict/list objects.
    if isinstance(obj, dict):
        return {key: serialize(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [serialize(item) for item in obj]

    # Scalars that the json module cannot encode on its own.
    if isinstance(obj, ObjectId):
        return str(obj)
    if isinstance(obj, (datetime, date)):
        return obj.isoformat()
    if isinstance(obj, bytes):
        return obj.decode("utf-8", errors="replace")

    return obj
|
|
|
|
|
|
def parse_filter(filter_json) -> dict:
    """Normalize a filter-like argument given as a JSON string or a dict.

    * A non-empty dict is returned as-is (same object).
    * A string that is not blank is decoded with ``json.loads`` and the decoded
      value is returned unchanged, so JSON that is not an object (e.g. an
      array) comes back as that Python value.
    * Everything else -- ``None``, empty/blank strings, empty dicts and values
      of any other type -- yields a new empty dict.

    Raises:
        json.JSONDecodeError: if a non-blank string is not valid JSON.
    """
    if isinstance(filter_json, str):
        has_content = bool(filter_json.strip())
        return json.loads(filter_json) if has_content else {}
    if isinstance(filter_json, dict) and filter_json:
        return filter_json
    return {}
|
|
|
|
|
|
# FastMCP instance for this module: the @mcp.tool() decorators below attach to it
# and mcp.run() is called from the __main__ guard at the bottom of the file.
mcp = FastMCP("janssen-mongo")
|
|
|
|
|
|
@mcp.tool()
def list_databases() -> dict:
    """List all databases on the MongoDB server (excludes admin/config/local).

    Returns a dict with ``count`` (number of names) and ``databases`` (the
    names, in the order the server reported them).
    Any error is logged with its traceback and re-raised.
    """
    try:
        hidden = {"admin", "config", "local"}
        names = []
        for entry in client.list_databases():
            if entry["name"] not in hidden:
                names.append(entry["name"])
        return {"count": len(names), "databases": names}
    except Exception:
        log(f"list_databases error: {traceback.format_exc()}")
        raise
|
|
|
|
|
|
@mcp.tool()
def list_collections(db: str) -> dict:
    """List all collections in a database, sorted by name.

    Returns a dict with ``db`` (echoed back), ``count`` and ``collections``.
    Any error is logged with its traceback and re-raised.
    """
    try:
        database = client[db]
        names = sorted(database.list_collection_names())
        result = {"db": db}
        result["count"] = len(names)
        result["collections"] = names
        return result
    except Exception:
        log(f"list_collections error: {traceback.format_exc()}")
        raise
|
|
|
|
|
|
@mcp.tool()
def collection_stats(db: str, collection: str) -> dict:
    """Returns document count + schema sample built from the first 20 documents.

    ``schema`` maps every field name seen in the sampled documents to the
    Python type name of the first value observed for that field; later
    documents never overwrite an entry, so a field holding mixed types is
    reported with only one of them.

    Returns a dict with ``db``, ``collection``, ``count`` and ``schema``.
    Any error is logged with its traceback and re-raised.
    """
    try:
        col = client[db][collection]
        count = col.count_documents({})
        schema = {}
        for sample in col.find().limit(20):
            for name, value in sample.items():
                # setdefault keeps the type recorded by the first document.
                schema.setdefault(name, type(value).__name__)
        return {"db": db, "collection": collection, "count": count, "schema": schema}
    except Exception:
        log(f"collection_stats error: {traceback.format_exc()}")
        raise
|
|
|
|
|
|
@mcp.tool()
def find_documents(
    db: str,
    collection: str,
    filter_json: Optional[Union[str, dict]] = None,
    projection_json: Optional[Union[str, dict]] = None,
    sort_json: Optional[Union[str, dict]] = None,
    limit: int = 50,
    skip: int = 0,
) -> dict:
    """Query documents. filter/projection/sort are JSON strings or dicts, e.g. '{"patient_code":"CZ1-01"}'.

    sort example: '{"last_seen_at": -1}'. Limit max 500. Use skip for pagination.

    Args:
        db: Database name.
        collection: Collection name.
        filter_json: Query filter; omitted/empty matches every document.
        projection_json: Field projection; omitted/empty returns whole documents.
        sort_json: Mapping of field -> direction, applied in the mapping's order.
        limit: Maximum number of documents to return. Values above 500, and
            0 (which PyMongo would treat as "unlimited"), are clamped to 500;
            a negative value is treated as its absolute value.
        skip: Number of matching documents to skip before returning results.

    Returns:
        A dict with ``total`` (documents matching the filter, ignoring
        skip/limit), ``skip`` (echoed back), ``count`` (documents returned)
        and ``docs`` (the serialized documents).

    Any error is logged with its traceback and re-raised.
    """
    max_limit = 500
    try:
        col = client[db][collection]
        filt = parse_filter(filter_json)
        proj = parse_filter(projection_json) or None
        sort_spec = parse_filter(sort_json)

        # min(limit, 500) alone does not enforce the cap: PyMongo treats
        # limit(0) as "no limit" and a negative limit as abs(limit), so both
        # would let a caller pull far more than 500 documents.
        limit = abs(limit)
        if limit == 0 or limit > max_limit:
            limit = max_limit

        cursor = col.find(filt, proj)
        if sort_spec:
            cursor = cursor.sort(list(sort_spec.items()))
        cursor = cursor.skip(skip).limit(limit)

        docs = [serialize(doc) for doc in cursor]
        total = col.count_documents(filt)
        return {"total": total, "skip": skip, "count": len(docs), "docs": docs}
    except Exception:
        log(f"find_documents error: {traceback.format_exc()}")
        raise
|
|
|
|
|
|
@mcp.tool()
def aggregate(db: str, collection: str, pipeline_json: Union[str, list]) -> dict:
    """Run a MongoDB aggregation pipeline. pipeline_json is a JSON array of stages,
    e.g. '[{"$group": {"_id": "$patient_code", "count": {"$sum": 1}}}]'.

    Args:
        db: Database name.
        collection: Collection name.
        pipeline_json: The pipeline, either as a JSON string or as an
            already-parsed list of stage dicts (mirrors the other tools,
            which accept both a JSON string and the parsed value).

    Returns:
        A dict with ``count`` (number of result documents) and ``results``
        (the serialized result documents).

    Any error (including invalid JSON) is logged with its traceback and
    re-raised.
    """
    try:
        col = client[db][collection]
        # Only decode when we were actually handed text; json.loads() raises
        # TypeError on a list, which is what a client sending structured
        # arguments provides.
        if isinstance(pipeline_json, str):
            pipeline = json.loads(pipeline_json)
        else:
            pipeline = pipeline_json
        results = [serialize(doc) for doc in col.aggregate(pipeline)]
        return {"count": len(results), "results": results}
    except Exception:
        log(f"aggregate error: {traceback.format_exc()}")
        raise
|
|
|
|
|
|
@mcp.tool()
def field_coverage(db: str, collection: str, field: str) -> dict:
    """Check how many documents have a given field present and non-null/non-empty.

    Useful for verifying data consistency across a collection.
    Returns total count, how many have the field, how many are missing it, and coverage %.

    Args:
        db: Database name.
        collection: Collection name.
        field: Name of the field to check.

    Returns:
        A dict with ``field``, ``total`` (all documents), ``present``
        (documents where the field exists and is neither null nor the empty
        string), ``missing`` (``total - present``) and ``coverage_pct``
        (``present / total`` as a percentage rounded to one decimal; 0 for an
        empty collection).

    Any error is logged with its traceback and re-raised.
    """
    try:
        col = client[db][collection]
        total = col.count_documents({})
        # A dict literal cannot carry two "$ne" keys -- the second silently
        # replaces the first, which dropped the null check. "$nin" expresses
        # "neither null nor empty string" in a single operator.
        present = col.count_documents({field: {"$exists": True, "$nin": [None, ""]}})
        missing = total - present
        return {
            "field": field,
            "total": total,
            "present": present,
            "missing": missing,
            "coverage_pct": round(present / total * 100, 1) if total else 0,
        }
    except Exception:
        log(f"field_coverage error: {traceback.format_exc()}")
        raise
|
|
|
|
|
|
@mcp.tool()
def count_documents(db: str, collection: str, filter_json: Optional[Union[str, dict]] = None) -> dict:
    """Count documents matching an optional filter. Fast way to check data without fetching content.

    Returns a dict with ``db`` and ``collection`` (echoed back), ``filter``
    (the parsed filter that was applied) and ``count``.
    Any error is logged with its traceback and re-raised.
    """
    try:
        target = client[db][collection]
        query = parse_filter(filter_json)
        matched = target.count_documents(query)
        result = {"db": db, "collection": collection}
        result["filter"] = query
        result["count"] = matched
        return result
    except Exception:
        log(f"count_documents error: {traceback.format_exc()}")
        raise
|
|
|
|
|
|
@mcp.tool()
def schema_diff(
    db: str,
    collection: str,
    filter_a_json: Optional[Union[str, dict]] = None,
    filter_b_json: Optional[Union[str, dict]] = None,
    sample_size: int = 50,
) -> dict:
    """Compare field presence between two groups of documents in a collection.

    filter_a_json and filter_b_json define the two groups (defaults to all documents).
    Only top-level field names are compared, taken from up to sample_size
    documents of each group.

    Returns a dict with ``only_in_a``, ``only_in_b`` and ``in_both`` (sorted
    field names) plus ``count_a`` / ``count_b`` (number of distinct field
    names seen in each group) -- useful for diagnosing why some documents are
    missing fields that others have.
    Any error is logged with its traceback and re-raised.
    """
    try:
        col = client[db][collection]
        query_a = parse_filter(filter_a_json)
        query_b = parse_filter(filter_b_json)

        # Collect the set of top-level keys for group A, then for group B.
        key_sets = []
        for query in (query_a, query_b):
            cursor = col.find(query).limit(sample_size)
            key_sets.append({name for doc in cursor for name in doc.keys()})
        keys_a, keys_b = key_sets

        report = {}
        report["only_in_a"] = sorted(keys_a - keys_b)
        report["only_in_b"] = sorted(keys_b - keys_a)
        report["in_both"] = sorted(keys_a & keys_b)
        report["count_a"] = len(keys_a)
        report["count_b"] = len(keys_b)
        return report
    except Exception:
        log(f"schema_diff error: {traceback.format_exc()}")
        raise
|
|
|
|
|
|
@mcp.tool()
def ping() -> dict:
    """Check if MongoDB is reachable. Returns server version and latency.

    On success the dict holds ``status`` ("ok"), ``version``, ``latency_ms``
    (duration of one server_info() call, rounded to 0.1 ms) and ``host``.
    On any failure nothing is raised; the dict holds ``status`` ("error"),
    ``error`` (the exception text) and ``host`` instead.
    """
    try:
        import time

        began = time.monotonic()
        server = client.server_info()
        elapsed_ms = (time.monotonic() - began) * 1000

        outcome = {"status": "ok"}
        outcome["version"] = server.get("version", "unknown")
        outcome["latency_ms"] = round(elapsed_ms, 1)
        outcome["host"] = MONGO_HOST
        return outcome
    except Exception as exc:
        return {"status": "error", "error": str(exc), "host": MONGO_HOST}
|
|
|
|
|
|
@mcp.tool()
def distinct_values(db: str, collection: str, field: str, filter_json: Optional[Union[str, dict]] = None) -> dict:
    """Return distinct values of a field, optionally filtered. Useful for exploring enums and categories.

    Returns a dict with ``field`` (echoed back), ``count`` (number of
    distinct values) and ``values`` (the serialized values).
    Any error is logged with its traceback and re-raised.
    """
    try:
        target = client[db][collection]
        query = parse_filter(filter_json)
        values = []
        for raw in target.distinct(field, query):
            values.append(serialize(raw))
        return {"field": field, "count": len(values), "values": values}
    except Exception:
        log(f"distinct_values error: {traceback.format_exc()}")
        raise
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: announce startup on stderr, then start the server
    # through the FastMCP instance.
    log("MCP MongoDB server started (FastMCP)")
    mcp.run()
|