Files
janssen/mcp_postgres.py
T
2026-06-04 11:40:45 +02:00

360 lines
12 KiB
Python

#!/usr/bin/env python3
"""
MCP server for PostgreSQL — built on FastMCP + psycopg v3.
Run: python mcp_postgres.py
Default connection settings (see the PG_* constants):
Host: 192.168.1.76:5432
User: vladimir.buzalka
Default DB: postgres (can be switched per call via the `db` tool argument)
"""
import json
import os
import re
import sys
import time
import traceback
import uuid
from datetime import datetime, date
from datetime import time as dt_time
from decimal import Decimal
from typing import Optional, Union

import psycopg
from psycopg.rows import dict_row
from mcp.server.fastmcp import FastMCP
# Connection settings. Each one can be overridden with the standard libpq
# environment variable (PGHOST, PGPORT, PGUSER, PGPASSWORD, PGDATABASE);
# the literals below are only fallbacks.
PG_HOST = os.environ.get("PGHOST", "192.168.1.76")
PG_PORT = int(os.environ.get("PGPORT", "5432"))
PG_USER = os.environ.get("PGUSER", "vladimir.buzalka")
# The password must not live in source control. When PGPASSWORD is unset this
# is None; psycopg drops None-valued connection arguments, so libpq falls back
# to its own lookup (e.g. the ~/.pgpass password file).
# NOTE(review): a plaintext password used to be hard-coded here — rotate it.
PG_PASSWORD = os.environ.get("PGPASSWORD")
PG_DEFAULT_DB = os.environ.get("PGDATABASE", "postgres")

# Lowercase keywords that start statements which modify data, schema,
# privileges or server state. `query` rejects statements flagged by is_write()
# (which consults this tuple) and `explain` refuses to EXPLAIN ANALYZE them.
WRITE_KEYWORDS = (
    "insert", "update", "delete", "drop", "truncate",
    "alter", "create", "grant", "revoke", "comment", "merge",
    # Further state-changing statements.
    "copy", "call", "do", "refresh", "reindex", "vacuum", "cluster",
    "lock", "reassign", "security", "import",
)
def log(msg: str):
    """Write *msg* plus a newline to stderr and flush right away.

    Diagnostics go to stderr — presumably because stdout carries the MCP
    stdio transport started by ``mcp.run()`` (confirm the transport).
    """
    err = sys.stderr
    err.write(f"{msg}\n")
    err.flush()
# Database name -> cached autocommit connection, populated lazily by get_conn().
_pool: dict[str, psycopg.Connection] = {}


def get_conn(db: Optional[str] = None) -> psycopg.Connection:
    """Return an autocommit connection to database *db*, opening it on first use.

    A falsy *db* (``None`` or ``""``) selects PG_DEFAULT_DB. Connections are
    cached per database name; a cached one whose ``closed`` flag is set is
    replaced by a freshly opened connection.
    """
    dbname = db or PG_DEFAULT_DB
    cached = _pool.get(dbname)
    if cached is None or cached.closed:
        cached = psycopg.connect(
            host=PG_HOST,
            port=PG_PORT,
            user=PG_USER,
            password=PG_PASSWORD,
            dbname=dbname,
            connect_timeout=5,
            autocommit=True,  # every statement commits on its own
        )
        _pool[dbname] = cached
    return cached
def serialize(value):
    """Convert a value returned by psycopg into a JSON-friendly form.

    - ``datetime`` / ``date`` / ``time`` -> ISO 8601 string
    - ``Decimal`` -> ``float`` (very wide numerics may lose precision)
    - ``bytes`` / ``memoryview`` -> text decoded as UTF-8, undecodable bytes
      replaced with U+FFFD (``repr`` of the value if conversion fails)
    - ``uuid.UUID`` -> canonical hyphenated string
    - ``dict`` -> dict with serialized values; ``list`` / ``tuple`` -> list
    Anything else is returned unchanged.
    """
    # SQL time/timetz values get the same ISO treatment as date/timestamp.
    if isinstance(value, (datetime, date, dt_time)):
        return value.isoformat()
    if isinstance(value, Decimal):
        return float(value)
    if isinstance(value, (bytes, memoryview)):
        try:
            return bytes(value).decode("utf-8", errors="replace")
        except Exception:
            # e.g. a released memoryview cannot be copied into bytes.
            return repr(value)
    if isinstance(value, uuid.UUID):
        return str(value)
    if isinstance(value, dict):
        return {k: serialize(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [serialize(v) for v in value]
    return value
def serialize_rows(rows):
    """Return a new list of new dicts with every cell passed through serialize().

    Each row must expose ``.items()`` (psycopg ``dict_row`` rows are dicts);
    column names are kept as-is.
    """
    def _convert(record):
        # One output dict per input row; keys untouched, values serialized.
        return {column: serialize(cell) for column, cell in record.items()}

    return list(map(_convert, rows))
# Text that may precede a statement's first keyword: whitespace, `-- line`
# comments, `/* block */` comments and opening parentheses. Every alternative
# starts with a distinct character, so matching stays linear.
_SQL_PREFIX_NOISE = re.compile(r"(?:\s|--[^\n]*|/\*.*?\*/|\()*", re.DOTALL)
_SQL_FIRST_WORD = re.compile(r"[a-z]+")
# Data-modifying keywords that can hide inside a WITH (writable CTE) or
# behind EXPLAIN ANALYZE.
_SQL_DML_WORD = re.compile(r"\b(?:insert|update|delete|merge)\b")
_SQL_ANALYZE_WORD = re.compile(r"\banaly[sz]e\b")


def is_write(sql: str) -> bool:
    """Heuristically report whether *sql* looks like a data/schema-changing statement.

    Leading whitespace, ``--`` / ``/* */`` comments and opening parentheses are
    skipped. The first keyword is then compared, case-insensitively and as a
    whole word, against WRITE_KEYWORDS. A WITH statement also counts as a
    write when it contains INSERT/UPDATE/DELETE/MERGE (data-modifying CTE). An
    EXPLAIN statement counts as a write when it contains ANALYZE together with
    one of those keywords, because EXPLAIN ANALYZE executes the statement.

    This is a best-effort client-side check: keywords inside string literals
    can cause false positives. It is not a substitute for server-side
    read-only enforcement.
    """
    lowered = sql.lower()
    body = lowered[_SQL_PREFIX_NOISE.match(lowered).end():]
    first = _SQL_FIRST_WORD.match(body)
    if first is None:
        return False
    keyword = first.group()
    if keyword in WRITE_KEYWORDS:
        return True
    if keyword == "with":
        return _SQL_DML_WORD.search(body) is not None
    if keyword == "explain":
        return bool(_SQL_ANALYZE_WORD.search(body) and _SQL_DML_WORD.search(body))
    return False
# Verify default DB on startup.
# Fail fast: this runs at import time. It opens a connection to the default
# database (get_conn() also caches it) and logs the server version. If anything
# fails, the process exits with status 1 before the FastMCP server object
# below is created.
try:
    c = get_conn()
    with c.cursor() as cur:
        cur.execute("select version()")
        ver = cur.fetchone()[0]
    log(f"Connected to PostgreSQL ({PG_HOST}:{PG_PORT}) — {ver}")
except Exception as e:
    log(f"PostgreSQL connection failed: {e}")
    sys.exit(1)
# Server instance; every function decorated with @mcp.tool() below is
# registered on it as an MCP tool.
mcp = FastMCP("janssen-postgres")
@mcp.tool()
def ping() -> dict:
    """Check if PostgreSQL is reachable. Returns server version, latency, host."""
    # Failures (connect or query) are reported in the result, never raised.
    try:
        t0 = time.monotonic()
        with get_conn().cursor() as cur:
            cur.execute("select version()")
            (version,) = cur.fetchone()
        elapsed_ms = (time.monotonic() - t0) * 1000
    except Exception as exc:
        return {"status": "error", "error": str(exc), "host": PG_HOST}
    return {
        "status": "ok",
        "version": version,
        "latency_ms": round(elapsed_ms, 1),
        "host": PG_HOST,
        "port": PG_PORT,
    }
@mcp.tool()
def list_databases() -> dict:
    """List all non-template databases on the server."""
    try:
        with get_conn().cursor(row_factory=dict_row) as cur:
            cur.execute("select datname from pg_database where datistemplate = false order by datname")
            names = [record["datname"] for record in cur.fetchall()]
    except Exception:
        # Log the full traceback to stderr, then re-raise to the caller.
        log(f"list_databases error: {traceback.format_exc()}")
        raise
    return {"count": len(names), "databases": names}
@mcp.tool()
def list_schemas(db: Optional[str] = None) -> dict:
    """List schemas in a database (excludes pg_* and information_schema)."""
    # The LIKE pattern escapes `_` (a single-character wildcard in LIKE) so
    # only names literally starting with "pg_" are excluded.
    schema_sql = """
        select schema_name
        from information_schema.schemata
        where schema_name not like 'pg\\_%' escape '\\'
        and schema_name <> 'information_schema'
        order by schema_name
    """
    try:
        with get_conn(db).cursor(row_factory=dict_row) as cur:
            cur.execute(schema_sql)
            names = [record["schema_name"] for record in cur.fetchall()]
    except Exception:
        log(f"list_schemas error: {traceback.format_exc()}")
        raise
    return {"db": db or PG_DEFAULT_DB, "count": len(names), "schemas": names}
@mcp.tool()
def list_tables(db: Optional[str] = None, schema: str = "public") -> dict:
    """List tables (and views) in a schema."""
    try:
        with get_conn(db).cursor(row_factory=dict_row) as cur:
            cur.execute(
                """
                select table_name, table_type
                from information_schema.tables
                where table_schema = %s
                order by table_name
                """,
                (schema,),
            )
            found = cur.fetchall()  # list of {"table_name", "table_type"} dicts
    except Exception:
        log(f"list_tables error: {traceback.format_exc()}")
        raise
    return {
        "db": db or PG_DEFAULT_DB,
        "schema": schema,
        "count": len(found),
        "tables": found,
    }
@mcp.tool()
def describe_table(table: str, db: Optional[str] = None, schema: str = "public") -> dict:
    """Return columns, types, nullability, defaults, indexes and row count for a table."""
    try:
        # Aliased because several tools in this file take a parameter named `sql`.
        from psycopg import sql as pgsql

        conn = get_conn(db)
        with conn.cursor(row_factory=dict_row) as cur:
            cur.execute(
                """
                select column_name, data_type, is_nullable, column_default, character_maximum_length
                from information_schema.columns
                where table_schema = %s and table_name = %s
                order by ordinal_position
                """,
                (schema, table),
            )
            columns = cur.fetchall()
            cur.execute(
                """
                select indexname, indexdef
                from pg_indexes
                where schemaname = %s and tablename = %s
                order by indexname
                """,
                (schema, table),
            )
            indexes = cur.fetchall()
            # Identifiers cannot be passed as %s parameters; compose them with
            # psycopg.sql so names containing `"` are escaped correctly and a
            # crafted schema/table argument cannot inject SQL.
            count_stmt = pgsql.SQL("select count(*) as c from {}").format(
                pgsql.Identifier(schema, table)
            )
            try:
                cur.execute(count_stmt)
                row_count = cur.fetchone()["c"]
            except Exception as e:
                # Best effort: report the failure (missing table, no privilege,
                # ...) in the result instead of failing the whole call.
                row_count = f"error: {e}"
        return {
            "db": db or PG_DEFAULT_DB,
            "schema": schema,
            "table": table,
            "row_count": row_count,
            "columns": columns,
            "indexes": indexes,
        }
    except Exception:
        log(f"describe_table error: {traceback.format_exc()}")
        raise
@mcp.tool()
def query(
    sql: str,
    db: Optional[str] = None,
    params_json: Optional[Union[str, list]] = None,
    limit: int = 100,
) -> dict:
    """Run a READ-ONLY SQL query (SELECT / WITH / EXPLAIN / SHOW).
    For writes use `execute` (which requires `confirmed=True`).
    params_json: JSON array of positional parameters for %s placeholders, e.g. '["CZ1-01", 5]'.
    limit: cap on returned rows (max 1000). The query itself is NOT modified — limit is a Python-side cutoff;
    `truncated` in the result tells whether more rows were available.
    The statement runs inside a READ ONLY transaction that is always rolled back.
    """
    try:
        if is_write(sql):
            return {
                "status": "rejected",
                "reason": "Write/DDL statement detected. Use `execute` (with preview + confirmed=True).",
            }
        limit = min(max(limit, 1), 1000)
        params = json.loads(params_json) if isinstance(params_json, str) else (params_json or [])
        conn = get_conn(db)
        # is_write() is only a keyword heuristic, so let the server enforce
        # read-only semantics as well: data-modifying CTEs, side-effecting
        # functions, etc. fail inside a READ ONLY transaction. force_rollback
        # also discards any session state (SET ...) the statement changes.
        # NOTE(review): a string holding several `;`-separated statements
        # (e.g. one containing COMMIT) may still escape this guard; a
        # dedicated read-only database role is the robust fix.
        with conn.transaction(force_rollback=True):
            with conn.cursor(row_factory=dict_row) as cur:
                cur.execute("set transaction read only")
                cur.execute(sql, params)
                if cur.description is None:
                    # The statement produced no result set.
                    return {
                        "status": "ok",
                        "db": db or PG_DEFAULT_DB,
                        "rowcount": cur.rowcount,
                        "rows": [],
                    }
                rows = cur.fetchmany(limit)
                # One extra fetch reveals whether the cutoff dropped rows.
                truncated = cur.fetchone() is not None
        return {
            "status": "ok",
            "db": db or PG_DEFAULT_DB,
            "row_count": len(rows),
            "limit": limit,
            "truncated": truncated,
            "rows": serialize_rows(rows),
        }
    except Exception as e:
        log(f"query error: {traceback.format_exc()}")
        return {"status": "error", "error": str(e), "sql": sql}
@mcp.tool()
def explain(sql: str, db: Optional[str] = None, analyze: bool = False) -> dict:
    """Return EXPLAIN (ANALYZE, BUFFERS) plan for a SELECT-style query. analyze=False is safe; True actually runs the query (inside a read-only transaction that is always rolled back)."""
    try:
        if is_write(sql) and analyze:
            return {"status": "rejected", "reason": "Refuse to EXPLAIN ANALYZE a write statement."}
        prefix = "explain (analyze, buffers, format json) " if analyze else "explain (format json) "
        conn = get_conn(db)
        # EXPLAIN ANALYZE really executes the statement, and is_write() is only
        # a keyword heuristic. Run inside a READ ONLY transaction, so writes
        # fail server-side, and roll it back in every case.
        # NOTE(review): several `;`-separated statements in `sql` may still
        # escape this guard; a read-only database role is the robust fix.
        with conn.transaction(force_rollback=True):
            with conn.cursor() as cur:
                cur.execute("set transaction read only")
                cur.execute(prefix + sql)
                plan = cur.fetchone()[0]
        return {"status": "ok", "analyze": analyze, "plan": plan}
    except Exception as e:
        log(f"explain error: {traceback.format_exc()}")
        return {"status": "error", "error": str(e)}
@mcp.tool()
def preview_execute(
    sql: str,
    db: Optional[str] = None,
    params_json: Optional[Union[str, list]] = None,
) -> dict:
    """Preview a write/DDL statement using a transaction + ROLLBACK. Reports rowcount that WOULD be affected.
    Always call this first and present the result to the user before calling `execute` with confirmed=True.
    """
    target_db = db or PG_DEFAULT_DB
    try:
        if isinstance(params_json, str):
            params = json.loads(params_json)
        else:
            params = params_json or []
        # Deliberately not the cached autocommit connection from get_conn():
        # a dedicated autocommit=False connection keeps the statement inside a
        # transaction that is rolled back below.
        # NOTE: a rollback cannot undo everything (e.g. sequence values consumed
        # by nextval() stay consumed), and statements that refuse to run inside
        # a transaction block (e.g. VACUUM) fail here.
        with psycopg.connect(
            host=PG_HOST,
            port=PG_PORT,
            user=PG_USER,
            password=PG_PASSWORD,
            dbname=target_db,
            connect_timeout=5,
            autocommit=False,
        ) as conn:
            with conn.cursor() as cur:
                cur.execute(sql, params)
                affected = cur.rowcount
            conn.rollback()
    except Exception as e:
        log(f"preview_execute error: {traceback.format_exc()}")
        return {"status": "error", "error": str(e), "sql": sql}
    return {
        "status": "preview",
        "db": target_db,
        "sql": sql,
        "params": params,
        "would_affect_rows": affected,
        "note": "Statement was executed inside a transaction and rolled back. Present this to the user and call `execute` with confirmed=True only after explicit approval.",
    }
@mcp.tool()
def execute(
    sql: str,
    db: Optional[str] = None,
    params_json: Optional[Union[str, list]] = None,
    confirmed: bool = False,
) -> dict:
    """Run a write/DDL statement (INSERT/UPDATE/DELETE/CREATE/ALTER/DROP/...).
    REQUIRES confirmed=True — first call `preview_execute` and obtain explicit user approval.
    """
    # Guard clause: nothing touches the database without explicit confirmation.
    if not confirmed:
        return {
            "status": "aborted",
            "reason": "confirmed=False. Call preview_execute first, present the impact to the user, then re-call with confirmed=True.",
        }
    target_db = db or PG_DEFAULT_DB
    try:
        if isinstance(params_json, str):
            params = json.loads(params_json)
        else:
            params = params_json or []
        # Runs on the cached autocommit connection, so the change is committed
        # as soon as the statement succeeds.
        with get_conn(db).cursor() as cur:
            cur.execute(sql, params)
            affected = cur.rowcount
        log(f"execute: db={target_db} rowcount={affected}")
        return {
            "status": "ok",
            "db": target_db,
            "rowcount": affected,
            "sql": sql,
        }
    except Exception as e:
        log(f"execute error: {traceback.format_exc()}")
        return {"status": "error", "error": str(e), "sql": sql}
@mcp.tool()
def list_extensions(db: Optional[str] = None) -> dict:
    """List installed extensions in the database."""
    try:
        with get_conn(db).cursor(row_factory=dict_row) as cur:
            cur.execute("select extname, extversion from pg_extension order by extname")
            installed = cur.fetchall()  # list of {"extname", "extversion"} dicts
    except Exception:
        log(f"list_extensions error: {traceback.format_exc()}")
        raise
    return {"db": db or PG_DEFAULT_DB, "count": len(installed), "extensions": installed}
# Script entry point. The startup connectivity check above has already run
# (and exited on failure) by the time this executes.
if __name__ == "__main__":
    log("MCP PostgreSQL server started (FastMCP)")
    # Serve the registered tools using FastMCP's default transport
    # (presumably stdio — confirm).
    mcp.run()