#!/usr/bin/env python3
"""
MCP server for PostgreSQL, built on FastMCP + psycopg 3.

Run: python mcp_postgres.py

Host: 192.168.1.76:5432
User: vladimir.buzalka
Default DB: postgres (can be switched per call with the `db` tool argument)
"""
|
|
|
|
import json
import os
import sys
import time
import traceback
from datetime import datetime, date
from decimal import Decimal
from typing import Optional, Union

import psycopg
from psycopg.rows import dict_row
from mcp.server.fastmcp import FastMCP
|
|
|
|
# Connection settings. Each one can be overridden through an environment
# variable of the same name, so credentials do not have to live in source code.
# NOTE(review): the hard-coded password fallback below is a committed secret.
# Set PG_PASSWORD in the environment, rotate this password, then remove the
# fallback value.
PG_HOST = os.environ.get("PG_HOST", "192.168.1.76")
PG_PORT = int(os.environ.get("PG_PORT", "5432"))
PG_USER = os.environ.get("PG_USER", "vladimir.buzalka")
PG_PASSWORD = os.environ.get("PG_PASSWORD", "Vlado7309208104++")
PG_DEFAULT_DB = os.environ.get("PG_DEFAULT_DB", "postgres")

# Leading SQL keywords that mark a statement as a write / DDL / maintenance
# command. They are used by is_write() to steer such statements away from the
# read-only tools. Keep them lower-case: is_write() lower-cases the SQL.
WRITE_KEYWORDS = (
    "insert", "update", "delete", "drop", "truncate",
    "alter", "create", "grant", "revoke", "comment", "merge",
    # Commands that can modify data or the catalog without one of the keywords above.
    "copy", "call", "do", "vacuum", "reindex", "cluster", "refresh",
    "lock", "reassign", "security", "import",
)
|
|
|
|
|
|
def log(msg: str):
    """Write *msg* to stderr as a single line and flush it right away.

    Diagnostics deliberately avoid stdout, presumably because stdout carries
    the MCP stdio transport (NOTE(review): confirm against how the server is run).
    """
    stream = sys.stderr
    stream.write(f"{msg}\n")
    stream.flush()
|
|
|
|
|
|
# Per-database connection cache used by get_conn(); keys are database names.
_pool: dict[str, psycopg.Connection] = {}


def get_conn(db: Optional[str] = None) -> psycopg.Connection:
    """Return a cached autocommit connection to *db* (PG_DEFAULT_DB when falsy).

    A new connection is opened (5 s connect timeout) when none is cached for
    that database, or when the cached one reports ``closed``. The new
    connection replaces the old entry in ``_pool``.
    """
    dbname = db or PG_DEFAULT_DB
    cached = _pool.get(dbname)
    if cached is None or cached.closed:
        cached = psycopg.connect(
            host=PG_HOST,
            port=PG_PORT,
            user=PG_USER,
            password=PG_PASSWORD,
            dbname=dbname,
            connect_timeout=5,
            autocommit=True,
        )
        _pool[dbname] = cached
    return cached
|
|
|
|
|
|
def serialize(value):
    """Recursively convert *value* into JSON-friendly Python values.

    - dict: values are converted, keys are kept as-is.
    - list / tuple: becomes a list of converted items.
    - Decimal: becomes float (may lose precision).
    - date / datetime: becomes an ISO-8601 string.
    - bytes / memoryview: decoded as UTF-8, with invalid bytes replaced.
    - Anything else is returned unchanged.
    """
    if isinstance(value, dict):
        return {key: serialize(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return list(map(serialize, value))
    if isinstance(value, Decimal):
        return float(value)
    if isinstance(value, (datetime, date)):
        return value.isoformat()
    if not isinstance(value, (bytes, memoryview)):
        return value
    try:
        return bytes(value).decode("utf-8", errors="replace")
    except Exception:
        return repr(value)
|
|
|
|
|
|
def serialize_rows(rows):
    """Return a list of new dicts, one per mapping in *rows*, with each value passed through serialize()."""
    def _convert(row):
        return {column: serialize(cell) for column, cell in row.items()}

    return [_convert(row) for row in rows]
|
|
|
|
|
|
def is_write(sql: str, keywords: Optional[tuple[str, ...]] = None) -> bool:
    """Heuristically decide whether *sql* starts with a write/DDL keyword.

    Leading whitespace, ``--`` line comments and ``/* ... */`` block comments
    are skipped first. The old version treated e.g. ``-- note\\ndelete ...``
    as a read. The first word is then compared case-insensitively, on a word
    boundary, against *keywords* (defaults to ``WRITE_KEYWORDS``).

    This only inspects the leading keyword of the first statement. It cannot
    see data-modifying CTEs, functions with side effects or later statements,
    so it must not be the only safeguard.
    """
    if keywords is None:
        keywords = WRITE_KEYWORDS
    text = sql
    while True:
        text = text.lstrip()
        if text.startswith("--"):
            # Line comment: drop everything up to and including the newline.
            _, _, text = text.partition("\n")
        elif text.startswith("/*"):
            # Block comment (nesting is not handled): drop through the closing */.
            end = text.find("*/", 2)
            text = "" if end == -1 else text[end + 2:]
        else:
            break
    head = text.lower()
    for kw in keywords:
        if head.startswith(kw):
            # Require a word boundary so an identifier like "updated" is not a hit.
            nxt = head[len(kw):len(kw) + 1]
            if not (nxt.isalnum() or nxt == "_"):
                return True
    return False
|
|
|
|
|
|
# Fail fast at import time: make sure the default database answers a trivial
# query, otherwise log the reason and exit with status 1.
try:
    _startup_conn = get_conn()
    with _startup_conn.cursor() as _startup_cur:
        _startup_cur.execute("select version()")
        (_server_version,) = _startup_cur.fetchone()
    log(f"Connected to PostgreSQL ({PG_HOST}:{PG_PORT}) — {_server_version}")
except Exception as _startup_err:
    log(f"PostgreSQL connection failed: {_startup_err}")
    sys.exit(1)

# Server instance; the functions below register themselves through @mcp.tool().
mcp = FastMCP(name="janssen-postgres")
|
|
|
|
|
|
@mcp.tool()
def ping() -> dict:
    """Check if PostgreSQL is reachable. Returns server version, latency, host."""
    # Latency covers obtaining the (possibly cached) connection plus one round trip.
    try:
        t0 = time.monotonic()
        with get_conn().cursor() as cur:
            cur.execute("select version()")
            (version,) = cur.fetchone()
        elapsed_ms = (time.monotonic() - t0) * 1000
    except Exception as exc:
        return {"status": "error", "error": str(exc), "host": PG_HOST}
    return {
        "status": "ok",
        "version": version,
        "latency_ms": round(elapsed_ms, 1),
        "host": PG_HOST,
        "port": PG_PORT,
    }
|
|
|
|
|
|
@mcp.tool()
def list_databases() -> dict:
    """List all non-template databases on the server."""
    # Errors are logged with a traceback and re-raised to the MCP framework.
    try:
        with get_conn().cursor(row_factory=dict_row) as cur:
            cur.execute("select datname from pg_database where datistemplate = false order by datname")
            names = [rec["datname"] for rec in cur.fetchall()]
    except Exception:
        log(f"list_databases error: {traceback.format_exc()}")
        raise
    return {"count": len(names), "databases": names}
|
|
|
|
|
|
@mcp.tool()
def list_schemas(db: Optional[str] = None) -> dict:
    """List schemas in a database (excludes pg_* and information_schema)."""
    # Per the PostgreSQL docs, information_schema.schemata only shows schemas the
    # current user owns or holds some privilege on. '\_' makes the underscore
    # literal, so only names starting with "pg_" are excluded.
    schema_sql = """
        select schema_name
        from information_schema.schemata
        where schema_name not like 'pg\\_%' escape '\\'
          and schema_name <> 'information_schema'
        order by schema_name
    """
    try:
        with get_conn(db).cursor(row_factory=dict_row) as cur:
            cur.execute(schema_sql)
            names = [rec["schema_name"] for rec in cur.fetchall()]
    except Exception:
        log(f"list_schemas error: {traceback.format_exc()}")
        raise
    return {"db": db or PG_DEFAULT_DB, "count": len(names), "schemas": names}
|
|
|
|
|
|
@mcp.tool()
def list_tables(db: Optional[str] = None, schema: str = "public") -> dict:
    """List tables (and views) in a schema."""
    # Each row is a dict with "table_name" and "table_type" (e.g. BASE TABLE / VIEW).
    try:
        with get_conn(db).cursor(row_factory=dict_row) as cur:
            cur.execute(
                """
                select table_name, table_type
                from information_schema.tables
                where table_schema = %s
                order by table_name
                """,
                (schema,),
            )
            found = cur.fetchall()
    except Exception:
        log(f"list_tables error: {traceback.format_exc()}")
        raise
    return {
        "db": db or PG_DEFAULT_DB,
        "schema": schema,
        "count": len(found),
        "tables": found,
    }
|
|
|
|
|
|
@mcp.tool()
def describe_table(table: str, db: Optional[str] = None, schema: str = "public") -> dict:
    """Return columns, types, nullability, defaults, indexes and row count for a table."""
    # psycopg.sql composes identifiers with correct quoting (embedded quotes are doubled).
    from psycopg import sql as pgsql

    try:
        conn = get_conn(db)
        with conn.cursor(row_factory=dict_row) as cur:
            cur.execute("""
                select column_name, data_type, is_nullable, column_default, character_maximum_length
                from information_schema.columns
                where table_schema = %s and table_name = %s
                order by ordinal_position
            """, (schema, table))
            columns = cur.fetchall()

            cur.execute("""
                select indexname, indexdef
                from pg_indexes
                where schemaname = %s and tablename = %s
                order by indexname
            """, (schema, table))
            indexes = cur.fetchall()

            # schema/table come straight from the tool caller. The previous
            # f'"{schema}"."{table}"' let a name containing a double quote break
            # out of the identifier and inject SQL on the autocommit connection.
            count_query = pgsql.SQL("select count(*) as c from {}").format(
                pgsql.Identifier(schema, table)
            )
            # Best effort: a missing table or missing privilege is reported in
            # the result instead of failing the whole description.
            try:
                cur.execute(count_query)
                row_count = cur.fetchone()["c"]
            except Exception as e:
                row_count = f"error: {e}"

        return {
            "db": db or PG_DEFAULT_DB,
            "schema": schema,
            "table": table,
            "row_count": row_count,
            "columns": columns,
            "indexes": indexes,
        }
    except Exception:
        log(f"describe_table error: {traceback.format_exc()}")
        raise
|
|
|
|
|
|
@mcp.tool()
def query(
    sql: str,
    db: Optional[str] = None,
    params_json: Optional[Union[str, list]] = None,
    limit: int = 100,
) -> dict:
    """Run a READ-ONLY SQL query (SELECT / WITH / EXPLAIN / SHOW).

    For writes use `execute` (which requires `confirmed=True`).
    The statement runs inside a READ ONLY transaction, so attempts to modify data fail.
    params_json: JSON array of positional parameters for %s placeholders, e.g. '["CZ1-01", 5]'.
    With parameters, write a literal percent sign as %%; without parameters, % needs no escaping.
    limit: cap on returned rows (max 1000). The query itself is NOT modified — limit is a Python-side cutoff;
    `truncated` is true when more rows were available.
    """
    try:
        if is_write(sql):
            return {
                "status": "rejected",
                "reason": "Write/DDL statement detected. Use `execute` (with preview + confirmed=True).",
            }
        limit = min(max(limit, 1), 1000)
        params = json.loads(params_json) if isinstance(params_json, str) else params_json
        # psycopg parses %-placeholders whenever params is not None, even for an
        # empty list, and then rejects a bare "%" (LIKE patterns, modulo).
        # With nothing to bind, pass None so the SQL is sent verbatim.
        if not params:
            params = None

        conn = get_conn(db)
        # The cached connection is autocommit and is_write() only sees the
        # leading keyword. Without this, a data-modifying CTE or a volatile
        # function would be committed. An explicit READ ONLY transaction makes
        # the server reject such writes, and it is rolled back on error.
        # NOTE(review): a multi-statement string containing COMMIT could still
        # end the transaction early.
        with conn.transaction():
            with conn.cursor(row_factory=dict_row) as cur:
                cur.execute("set transaction read only")
                cur.execute(sql, params)
                if cur.description is None:
                    return {
                        "status": "ok",
                        "db": db or PG_DEFAULT_DB,
                        "rowcount": cur.rowcount,
                        "rows": [],
                    }
                # Fetch one extra row to detect truncation without reading the rest.
                rows = cur.fetchmany(limit + 1)

        truncated = len(rows) > limit
        rows = rows[:limit]
        return {
            "status": "ok",
            "db": db or PG_DEFAULT_DB,
            "row_count": len(rows),
            "limit": limit,
            "truncated": truncated,
            "rows": serialize_rows(rows),
        }
    except Exception as e:
        log(f"query error: {traceback.format_exc()}")
        return {"status": "error", "error": str(e), "sql": sql}
|
|
|
|
|
|
@mcp.tool()
def explain(sql: str, db: Optional[str] = None, analyze: bool = False) -> dict:
    """Return EXPLAIN (ANALYZE, BUFFERS) plan for a SELECT-style query. analyze=False is safe; True actually runs the query (inside a transaction that is always rolled back)."""
    try:
        if is_write(sql) and analyze:
            return {"status": "rejected", "reason": "Refuse to EXPLAIN ANALYZE a write statement."}
        prefix = "explain (analyze, buffers, format json) " if analyze else "explain (format json) "
        conn = get_conn(db)
        # EXPLAIN ANALYZE really executes the statement, and the cached
        # connection is autocommit. is_write() only sees the leading keyword,
        # so writes in a CTE or a volatile function would otherwise be committed.
        # Always roll back: raising psycopg.Rollback inside conn.transaction()
        # rolls the transaction back, and the block swallows the exception.
        with conn.transaction():
            with conn.cursor() as cur:
                cur.execute(prefix + sql)
                plan = cur.fetchone()[0]
            raise psycopg.Rollback()
        return {"status": "ok", "analyze": analyze, "plan": plan}
    except Exception as e:
        log(f"explain error: {traceback.format_exc()}")
        return {"status": "error", "error": str(e)}
|
|
|
|
|
|
@mcp.tool()
def preview_execute(
    sql: str,
    db: Optional[str] = None,
    params_json: Optional[Union[str, list]] = None,
) -> dict:
    """Preview a write/DDL statement using a transaction + ROLLBACK. Reports rowcount that WOULD be affected.

    Always call this first and present the result to the user before calling `execute` with confirmed=True.
    With parameters, write a literal percent sign as %%; without parameters, % needs no escaping.
    """
    try:
        params = json.loads(params_json) if isinstance(params_json, str) else (params_json or [])
        # Use a fresh, autocommit=False connection so we can ROLLBACK safely.
        # NOTE(review): effects outside transactional control persist (e.g.
        # sequence advances from nextval), and a multi-statement string that
        # contains COMMIT would commit during the preview.
        with psycopg.connect(
            host=PG_HOST, port=PG_PORT, user=PG_USER, password=PG_PASSWORD,
            dbname=db or PG_DEFAULT_DB, connect_timeout=5, autocommit=False,
        ) as conn:
            with conn.cursor() as cur:
                # With nothing to bind, pass None: an empty list makes psycopg
                # parse placeholders and reject a bare "%" in the SQL text.
                cur.execute(sql, params or None)
                rowcount = cur.rowcount
            conn.rollback()
        return {
            "status": "preview",
            "db": db or PG_DEFAULT_DB,
            "sql": sql,
            "params": params,
            "would_affect_rows": rowcount,
            "note": "Statement was executed inside a transaction and rolled back. Present this to the user and call `execute` with confirmed=True only after explicit approval.",
        }
    except Exception as e:
        log(f"preview_execute error: {traceback.format_exc()}")
        return {"status": "error", "error": str(e), "sql": sql}
|
|
|
|
|
|
@mcp.tool()
def execute(
    sql: str,
    db: Optional[str] = None,
    params_json: Optional[Union[str, list]] = None,
    confirmed: bool = False,
) -> dict:
    """Run a write/DDL statement (INSERT/UPDATE/DELETE/CREATE/ALTER/DROP/...).

    REQUIRES confirmed=True — first call `preview_execute` and obtain explicit user approval.
    With parameters, write a literal percent sign as %%; without parameters, % needs no escaping.
    """
    if not confirmed:
        return {
            "status": "aborted",
            "reason": "confirmed=False. Call preview_execute first, present the impact to the user, then re-call with confirmed=True.",
        }
    try:
        params = json.loads(params_json) if isinstance(params_json, str) else (params_json or [])
        # The cached connection is autocommit, so the statement is committed immediately.
        conn = get_conn(db)
        with conn.cursor() as cur:
            # With nothing to bind, pass None: an empty list makes psycopg
            # parse placeholders and reject a bare "%" in the SQL text.
            cur.execute(sql, params or None)
            rowcount = cur.rowcount
        log(f"execute: db={db or PG_DEFAULT_DB} rowcount={rowcount}")
        return {
            "status": "ok",
            "db": db or PG_DEFAULT_DB,
            "rowcount": rowcount,
            "sql": sql,
        }
    except Exception as e:
        log(f"execute error: {traceback.format_exc()}")
        return {"status": "error", "error": str(e), "sql": sql}
|
|
|
|
|
|
@mcp.tool()
def list_extensions(db: Optional[str] = None) -> dict:
    """List installed extensions in the database."""
    # Each row is a dict with "extname" and "extversion"; errors are logged and re-raised.
    try:
        with get_conn(db).cursor(row_factory=dict_row) as cur:
            cur.execute("select extname, extversion from pg_extension order by extname")
            extensions = cur.fetchall()
    except Exception:
        log(f"list_extensions error: {traceback.format_exc()}")
        raise
    return {"db": db or PG_DEFAULT_DB, "count": len(extensions), "extensions": extensions}
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: announce startup on stderr, then serve the tools.
    # "stdio" is FastMCP.run()'s default transport; it is spelled out for clarity.
    log("MCP PostgreSQL server started (FastMCP)")
    mcp.run(transport="stdio")
|