Rework
This commit is contained in:
parent
ce5abcc217
commit
cfe4a641a6
40 changed files with 1728 additions and 1445 deletions
0
src/geoguessr_mcp/__init__.py
Normal file
0
src/geoguessr_mcp/__init__.py
Normal file
0
src/geoguessr_mcp/api/__init__.py
Normal file
0
src/geoguessr_mcp/api/__init__.py
Normal file
201
src/geoguessr_mcp/api/client.py
Normal file
201
src/geoguessr_mcp/api/client.py
Normal file
|
|
@ -0,0 +1,201 @@
|
|||
"""
|
||||
HTTP client for Geoguessr API communication.
|
||||
"""
|
||||
|
||||
import httpx
|
||||
import logging
|
||||
from typing import Optional
|
||||
from ..auth.session import SessionManager
|
||||
from .endpoints import EndpointBuilder, get_endpoint_info
|
||||
|
||||
from ..config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GeoguessrClient:
|
||||
"""
|
||||
Wrapper for Geoguessr API HTTP communication.
|
||||
|
||||
This client automatically handles:
|
||||
- Authentication via session manager
|
||||
- Endpoint routing (main API vs. game server)
|
||||
- Error handling and retries
|
||||
- Logging and debugging
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_manager: SessionManager,
|
||||
base_url: str = settings.GEOGUESSR_API_URL,
|
||||
game_server_url: str = settings.GAME_SERVER_URL,
|
||||
timeout: float = 30.0
|
||||
):
|
||||
"""
|
||||
Initialize the Geoguessr API client.
|
||||
|
||||
Args:
|
||||
session_manager: Session manager for authentication
|
||||
base_url: Base URL for Geoguessr API
|
||||
game_server_url: URL for game server API
|
||||
timeout: Request timeout in seconds
|
||||
"""
|
||||
self.session_manager = session_manager
|
||||
self.base_url = base_url
|
||||
self.game_server_url = game_server_url
|
||||
self.timeout = timeout
|
||||
|
||||
async def get_authenticated_client(
|
||||
self,
|
||||
session_token: Optional[str] = None
|
||||
) -> httpx.AsyncClient:
|
||||
"""
|
||||
Get an authenticated async HTTP client.
|
||||
|
||||
Args:
|
||||
session_token: Optional session token for authentication
|
||||
|
||||
Returns:
|
||||
Authenticated httpx.AsyncClient
|
||||
|
||||
Raises:
|
||||
ValueError: If no valid session is available
|
||||
"""
|
||||
session = await self.session_manager.get_session(session_token)
|
||||
if not session:
|
||||
raise ValueError(
|
||||
"No valid session available. Please:\n"
|
||||
"1. Use login() to authenticate, or\n"
|
||||
"2. Set GEOGUESSR_NCFA_COOKIE environment variable"
|
||||
)
|
||||
|
||||
client = httpx.AsyncClient(timeout=self.timeout)
|
||||
client.cookies.set(
|
||||
"_ncfa",
|
||||
session.ncfa_cookie,
|
||||
domain="www.geoguessr.com"
|
||||
)
|
||||
return client
|
||||
|
||||
def _get_base_url(self, endpoint: str, use_game_server: Optional[bool] = None) -> str:
|
||||
"""
|
||||
Determine the correct base URL for an endpoint.
|
||||
|
||||
Args:
|
||||
endpoint: API endpoint
|
||||
use_game_server: Explicitly set game server usage, or auto-detect
|
||||
|
||||
Returns:
|
||||
Appropriate base URL
|
||||
"""
|
||||
if use_game_server is None:
|
||||
# Auto-detect based on endpoint
|
||||
use_game_server = EndpointBuilder.is_game_server_endpoint(endpoint)
|
||||
|
||||
return self.game_server_url if use_game_server else self.base_url
|
||||
|
||||
async def get(
|
||||
self,
|
||||
endpoint: str,
|
||||
session_token: Optional[str] = None,
|
||||
use_game_server: Optional[bool] = None,
|
||||
params: Optional[dict] = None,
|
||||
**kwargs
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
Make a GET request to the Geoguessr API.
|
||||
|
||||
Args:
|
||||
endpoint: API endpoint (e.g., "/v3/profiles")
|
||||
session_token: Optional session token
|
||||
use_game_server: Whether to use game server URL (auto-detected if None)
|
||||
params: Query parameters
|
||||
**kwargs: Additional arguments to pass to httpx.get
|
||||
|
||||
Returns:
|
||||
httpx.Response
|
||||
|
||||
Raises:
|
||||
httpx.HTTPError: On HTTP errors
|
||||
"""
|
||||
base = self._get_base_url(endpoint, use_game_server)
|
||||
url = f"{base}{endpoint}"
|
||||
|
||||
# Get endpoint metadata for logging
|
||||
metadata = get_endpoint_info(endpoint)
|
||||
logger.debug(
|
||||
f"GET {url} - {metadata.get('description', 'Unknown endpoint')}"
|
||||
)
|
||||
|
||||
async with await self.get_authenticated_client(session_token) as client:
|
||||
response = await client.get(url, params=params, **kwargs)
|
||||
response.raise_for_status()
|
||||
logger.debug(f"GET {url} - Success ({response.status_code})")
|
||||
return response
|
||||
|
||||
async def post(
|
||||
self,
|
||||
endpoint: str,
|
||||
session_token: Optional[str] = None,
|
||||
use_game_server: Optional[bool] = None,
|
||||
json_data: Optional[dict] = None,
|
||||
**kwargs
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
Make a POST request to the Geoguessr API.
|
||||
|
||||
Args:
|
||||
endpoint: API endpoint
|
||||
session_token: Optional session token
|
||||
use_game_server: Whether to use game server URL (auto-detected if None)
|
||||
json_data: JSON data to send
|
||||
**kwargs: Additional arguments to pass to httpx.post
|
||||
|
||||
Returns:
|
||||
httpx.Response
|
||||
|
||||
Raises:
|
||||
httpx.HTTPError: On HTTP errors
|
||||
"""
|
||||
base = self._get_base_url(endpoint, use_game_server)
|
||||
url = f"{base}{endpoint}"
|
||||
|
||||
metadata = get_endpoint_info(endpoint)
|
||||
logger.debug(
|
||||
f"POST {url} - {metadata.get('description', 'Unknown endpoint')}"
|
||||
)
|
||||
|
||||
async with await self.get_authenticated_client(session_token) as client:
|
||||
response = await client.post(url, json=json_data, **kwargs)
|
||||
response.raise_for_status()
|
||||
logger.debug(f"POST {url} - Success ({response.status_code})")
|
||||
return response
|
||||
|
||||
async def request(
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
session_token: Optional[str] = None,
|
||||
use_game_server: Optional[bool] = None,
|
||||
**kwargs
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
Make a generic HTTP request.
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, etc.)
|
||||
endpoint: API endpoint
|
||||
session_token: Optional session token
|
||||
use_game_server: Whether to use game server URL
|
||||
**kwargs: Additional arguments to pass to httpx
|
||||
|
||||
Returns:
|
||||
httpx.Response
|
||||
"""
|
||||
base = self._get_base_url(endpoint, use_game_server)
|
||||
url = f"{base}{endpoint}"
|
||||
|
||||
async with await self.get_authenticated_client(session_token) as client:
|
||||
response = await client.request(method, url, **kwargs)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
368
src/geoguessr_mcp/api/endpoints.py
Normal file
368
src/geoguessr_mcp/api/endpoints.py
Normal file
|
|
@ -0,0 +1,368 @@
|
|||
"""
|
||||
Geoguessr API Endpoints
|
||||
Centralized endpoint definitions extracted from the Geoguessr API.
|
||||
|
||||
"""
|
||||
from ..config import settings
|
||||
|
||||
|
||||
class Endpoints:
|
||||
"""
|
||||
Centralized endpoint registry for Geoguessr API.
|
||||
|
||||
Usage:
|
||||
url = Endpoints.PROFILES.GET_PROFILE
|
||||
full_url = f"{GEOGUESSR_BASE_URL}{url}"
|
||||
"""
|
||||
|
||||
# ============================================================================
|
||||
# AUTHENTICATION ENDPOINTS
|
||||
# ============================================================================
|
||||
class AUTH:
|
||||
"""Authentication endpoints."""
|
||||
SIGNIN = "/v3/accounts/signin" # POST
|
||||
|
||||
# ============================================================================
|
||||
# PROFILE ENDPOINTS
|
||||
# ============================================================================
|
||||
class PROFILES:
|
||||
"""User profile and stats endpoints."""
|
||||
GET_PROFILE = "/v3/profiles" # GET - Get current user profile
|
||||
GET_STATS = "/v3/profiles/stats" # GET - Get user statistics
|
||||
GET_EXTENDED_STATS = "/v4/stats/me" # GET - Get extended statistics
|
||||
GET_ACHIEVEMENTS = "/v3/profiles/achievements" # GET - Get user achievements
|
||||
GET_USER_MAPS = "/v3/profiles/maps" # GET - Get user's custom maps
|
||||
|
||||
@staticmethod
|
||||
def get_public_profile(user_id: str) -> str:
|
||||
"""Get public profile by user ID."""
|
||||
return f"/v3/profiles/{user_id}"
|
||||
|
||||
@staticmethod
|
||||
def get_user_activities(user_id: str) -> str:
|
||||
"""Get user activities/feed."""
|
||||
return f"/v3/users/{user_id}/activities"
|
||||
|
||||
# ============================================================================
|
||||
# GAME ENDPOINTS
|
||||
# ============================================================================
|
||||
class GAMES:
|
||||
"""Game-related endpoints."""
|
||||
GET_UNFINISHED_GAMES = "/v3/social/events/unfinishedgames" # GET
|
||||
|
||||
@staticmethod
|
||||
def get_game_details(game_token: str) -> str:
|
||||
"""Get details for a specific game."""
|
||||
return f"/v3/games/{game_token}"
|
||||
|
||||
@staticmethod
|
||||
def get_streak_game(game_token: str) -> str:
|
||||
"""Get streak game details."""
|
||||
return f"/v3/games/streak/{game_token}"
|
||||
|
||||
# ============================================================================
|
||||
# GAME SERVER ENDPOINTS (Different base URL)
|
||||
# ============================================================================
|
||||
class GAME_SERVER:
|
||||
"""Game server endpoints (use GAME_SERVER_URL as base)."""
|
||||
GET_TOURNAMENTS = "/tournaments" # GET
|
||||
|
||||
@staticmethod
|
||||
def get_battle_royale(game_id: str) -> str:
|
||||
"""Get battle royale game."""
|
||||
return f"/battle-royale/{game_id}"
|
||||
|
||||
@staticmethod
|
||||
def get_duel(duel_id: str) -> str:
|
||||
"""Get duel details."""
|
||||
return f"/duels/{duel_id}"
|
||||
|
||||
@staticmethod
|
||||
def get_lobby(game_id: str) -> str:
|
||||
"""Get lobby information."""
|
||||
return f"/lobby/{game_id}"
|
||||
|
||||
# ============================================================================
|
||||
# COMPETITIVE/SEASONS ENDPOINTS
|
||||
# ============================================================================
|
||||
class COMPETITIVE:
|
||||
"""Competitive and season-related endpoints."""
|
||||
GET_ACTIVE_SEASON_STATS = "/v4/seasons/active/stats" # GET
|
||||
|
||||
@staticmethod
|
||||
def get_season_game(game_mode: str) -> str:
|
||||
"""Get season game for specific mode."""
|
||||
return f"/v4/seasons/game/{game_mode}"
|
||||
|
||||
# ============================================================================
|
||||
# CHALLENGE ENDPOINTS
|
||||
# ============================================================================
|
||||
class CHALLENGES:
|
||||
"""Challenge-related endpoints."""
|
||||
|
||||
@staticmethod
|
||||
def get_daily_challenge(endpoint: str = "current") -> str:
|
||||
"""
|
||||
Get daily challenge.
|
||||
|
||||
Args:
|
||||
endpoint: 'current', 'today', or specific date
|
||||
"""
|
||||
return f"/v3/challenges/daily-challenges/{endpoint}"
|
||||
|
||||
@staticmethod
|
||||
def get_challenge(challenge_token: str) -> str:
|
||||
"""Get challenge details."""
|
||||
return f"/v3/challenges/{challenge_token}"
|
||||
|
||||
# ============================================================================
|
||||
# SOCIAL/FRIENDS ENDPOINTS
|
||||
# ============================================================================
|
||||
class SOCIAL:
|
||||
"""Social and friends endpoints."""
|
||||
GET_FRIENDS_SUMMARY = "/v3/social/friends/summary" # GET
|
||||
GET_UNCLAIMED_BADGES = "/v3/social/badges/unclaimed" # GET
|
||||
GET_PERSONALIZED_MAPS = "/v3/social/maps/browse/personalized" # GET
|
||||
|
||||
@staticmethod
|
||||
def get_activity_feed(count: int = 10, page: int = 0) -> tuple[str, dict]:
|
||||
"""
|
||||
Get user activity feed.
|
||||
|
||||
Returns:
|
||||
Tuple of (endpoint, params_dict)
|
||||
"""
|
||||
return "/v4/feed/private", {"count": count, "page": page}
|
||||
|
||||
@staticmethod
|
||||
def get_friends_activities(time_frame: str, limit: int = 20) -> tuple[str, dict]:
|
||||
"""
|
||||
Get friends' activities.
|
||||
|
||||
Args:
|
||||
time_frame: Time frame for activities
|
||||
limit: Maximum number of activities
|
||||
|
||||
Returns:
|
||||
Tuple of (endpoint, params_dict)
|
||||
"""
|
||||
return "/v3/social/friends/activities", {"timeFrame": time_frame, "limit": limit}
|
||||
|
||||
# ============================================================================
|
||||
# MAPS ENDPOINTS
|
||||
# ============================================================================
|
||||
class MAPS:
|
||||
"""Map-related endpoints."""
|
||||
GET_PERSONALIZED_MAPS = "/v3/social/maps/browse/personalized" # GET
|
||||
|
||||
@staticmethod
|
||||
def get_map_details(map_id: str) -> str:
|
||||
"""Get map details."""
|
||||
return f"/maps/{map_id}"
|
||||
|
||||
@staticmethod
|
||||
def get_map_leaderboard(map_id: str) -> str:
|
||||
"""Get leaderboard for a map."""
|
||||
return f"/v3/scores/maps/{map_id}"
|
||||
|
||||
@staticmethod
|
||||
def search_maps(search_type: str, query: str, count: int = 20, page: int = 0) -> tuple[str, dict]:
|
||||
"""
|
||||
Search for maps.
|
||||
|
||||
Args:
|
||||
search_type: Type of search ('all', 'official', 'community', etc.)
|
||||
query: Search query
|
||||
count: Number of results per-page
|
||||
page: Page number
|
||||
|
||||
Returns:
|
||||
Tuple of (endpoint, params_dict)
|
||||
"""
|
||||
return f"/v3/social/maps/browse/{search_type}", {
|
||||
"q": query,
|
||||
"count": count,
|
||||
"page": page
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# EXPLORER MODE ENDPOINTS
|
||||
# ============================================================================
|
||||
class EXPLORER:
|
||||
"""Explorer mode endpoints."""
|
||||
GET_PROGRESS = "/v3/explorer" # GET - Get explorer mode progress
|
||||
|
||||
# ============================================================================
|
||||
# OBJECTIVES/REWARDS ENDPOINTS
|
||||
# ============================================================================
|
||||
class OBJECTIVES:
|
||||
"""Objectives and rewards endpoints."""
|
||||
GET_OBJECTIVES = "/v4/objectives" # GET - Get current objectives
|
||||
GET_UNCLAIMED_OBJECTIVES = "/v4/objectives/unclaimed" # GET - Get unclaimed rewards
|
||||
|
||||
# ============================================================================
|
||||
# SUBSCRIPTION ENDPOINTS
|
||||
# ============================================================================
|
||||
class SUBSCRIPTION:
|
||||
"""Subscription-related endpoints."""
|
||||
GET_SUBSCRIPTION_INFO = "/v3/subscriptions" # GET - Get subscription details
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ENDPOINT UTILITIES
|
||||
# ============================================================================
|
||||
|
||||
class EndpointBuilder:
|
||||
"""Utility class for building complete URLs."""
|
||||
|
||||
@staticmethod
|
||||
def build_url(endpoint: str, use_game_server: bool = False) -> str:
|
||||
"""
|
||||
Build complete URL for an endpoint.
|
||||
|
||||
Args:
|
||||
endpoint: The endpoint path
|
||||
use_game_server: Whether to use game server URL
|
||||
|
||||
Returns:
|
||||
Complete URL
|
||||
"""
|
||||
base = settings.GAME_SERVER_URL if use_game_server else settings.GEOGUESSR_BASE_URL
|
||||
return f"{base}{endpoint}"
|
||||
|
||||
@staticmethod
|
||||
def is_game_server_endpoint(endpoint: str) -> bool:
|
||||
"""
|
||||
Check if endpoint belongs to game server.
|
||||
|
||||
Args:
|
||||
endpoint: The endpoint path
|
||||
|
||||
Returns:
|
||||
True if it's a game server endpoint
|
||||
"""
|
||||
game_server_prefixes = [
|
||||
"/battle-royale/",
|
||||
"/duels/",
|
||||
"/lobby/",
|
||||
"/tournaments"
|
||||
]
|
||||
return any(endpoint.startswith(prefix) for prefix in game_server_prefixes)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ENDPOINT METADATA
|
||||
# ============================================================================
|
||||
|
||||
ENDPOINT_METADATA = {
|
||||
# Profile endpoints
|
||||
"/v3/profiles": {
|
||||
"method": "GET",
|
||||
"description": "Get current user profile",
|
||||
"auth_required": True,
|
||||
"response_type": "profile"
|
||||
},
|
||||
"/v3/profiles/stats": {
|
||||
"method": "GET",
|
||||
"description": "Get user statistics",
|
||||
"auth_required": True,
|
||||
"response_type": "stats"
|
||||
},
|
||||
"/v4/stats/me": {
|
||||
"method": "GET",
|
||||
"description": "Get extended statistics",
|
||||
"auth_required": True,
|
||||
"response_type": "extended_stats"
|
||||
},
|
||||
"/v3/profiles/achievements": {
|
||||
"method": "GET",
|
||||
"description": "Get user achievements",
|
||||
"auth_required": True,
|
||||
"response_type": "achievements"
|
||||
},
|
||||
|
||||
# Game endpoints
|
||||
"/v3/games/{game_token}": {
|
||||
"method": "GET",
|
||||
"description": "Get game details",
|
||||
"auth_required": True,
|
||||
"response_type": "game"
|
||||
},
|
||||
"/v3/social/events/unfinishedgames": {
|
||||
"method": "GET",
|
||||
"description": "Get unfinished games",
|
||||
"auth_required": True,
|
||||
"response_type": "games_list"
|
||||
},
|
||||
|
||||
# Competitive endpoints
|
||||
"/v4/seasons/active/stats": {
|
||||
"method": "GET",
|
||||
"description": "Get active season statistics",
|
||||
"auth_required": True,
|
||||
"response_type": "season_stats"
|
||||
},
|
||||
|
||||
# Social endpoints
|
||||
"/v4/feed/private": {
|
||||
"method": "GET",
|
||||
"description": "Get private activity feed",
|
||||
"auth_required": True,
|
||||
"response_type": "feed",
|
||||
"params": ["count", "page"]
|
||||
},
|
||||
"/v3/social/friends/summary": {
|
||||
"method": "GET",
|
||||
"description": "Get friends summary",
|
||||
"auth_required": True,
|
||||
"response_type": "friends"
|
||||
},
|
||||
|
||||
# Maps endpoints
|
||||
"/maps/{map_id}": {
|
||||
"method": "GET",
|
||||
"description": "Get map details",
|
||||
"auth_required": False,
|
||||
"response_type": "map"
|
||||
},
|
||||
"/v3/scores/maps/{map_id}": {
|
||||
"method": "GET",
|
||||
"description": "Get map leaderboard",
|
||||
"auth_required": True,
|
||||
"response_type": "leaderboard"
|
||||
},
|
||||
|
||||
# Explorer endpoints
|
||||
"/v3/explorer": {
|
||||
"method": "GET",
|
||||
"description": "Get explorer mode progress",
|
||||
"auth_required": True,
|
||||
"response_type": "explorer"
|
||||
},
|
||||
|
||||
# Objectives endpoints
|
||||
"/v4/objectives": {
|
||||
"method": "GET",
|
||||
"description": "Get current objectives",
|
||||
"auth_required": True,
|
||||
"response_type": "objectives"
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_endpoint_info(endpoint: str) -> dict:
|
||||
"""
|
||||
Get metadata for an endpoint.
|
||||
|
||||
Args:
|
||||
endpoint: The endpoint path
|
||||
|
||||
Returns:
|
||||
Dictionary with endpoint metadata
|
||||
"""
|
||||
return ENDPOINT_METADATA.get(endpoint, {
|
||||
"method": "GET",
|
||||
"description": "Unknown endpoint",
|
||||
"auth_required": True,
|
||||
"response_type": "unknown"
|
||||
})
|
||||
0
src/geoguessr_mcp/auth/__init__.py
Normal file
0
src/geoguessr_mcp/auth/__init__.py
Normal file
0
src/geoguessr_mcp/auth/middleware.py
Normal file
0
src/geoguessr_mcp/auth/middleware.py
Normal file
201
src/geoguessr_mcp/auth/session.py
Normal file
201
src/geoguessr_mcp/auth/session.py
Normal file
|
|
@ -0,0 +1,201 @@
|
|||
"""
|
||||
Session management for Geoguessr authentication.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import secrets
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserSession:
|
||||
"""Represents an authenticated Geoguessr session."""
|
||||
|
||||
ncfa_cookie: str
|
||||
user_id: str
|
||||
username: str
|
||||
email: str
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
expires_at: Optional[datetime] = None
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if the session is still valid."""
|
||||
if self.expires_at and datetime.now(UTC) > self.expires_at:
|
||||
return False
|
||||
return bool(self.ncfa_cookie)
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""Manages user sessions for the MCP server."""
|
||||
|
||||
def __init__(self, default_cookie: Optional[str] = None):
|
||||
self._sessions: dict[str, UserSession] = {}
|
||||
self._user_sessions: dict[str, str] = {}
|
||||
self._default_cookie: Optional[str] = default_cookie
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
@staticmethod
|
||||
def _generate_session_token() -> str:
|
||||
"""Generate a secure session token."""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
async def login(
|
||||
self, email: str, password: str, base_url: str = "https://www.geoguessr.com/api"
|
||||
) -> tuple[str, UserSession]:
|
||||
"""
|
||||
Authenticate with Geoguessr and create a session.
|
||||
|
||||
Args:
|
||||
email: User's email address
|
||||
password: User's password
|
||||
base_url: Geoguessr API base URL
|
||||
|
||||
Returns:
|
||||
tuple[str, UserSession]: (session_token, UserSession) on success
|
||||
|
||||
Raises:
|
||||
ValueError: On authentication failure
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
# Attempt to sign in
|
||||
response = await client.post(
|
||||
f"{base_url}/v3/accounts/signin",
|
||||
json={"email": email, "password": password},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
|
||||
if response.status_code == 401:
|
||||
raise ValueError("Invalid email or password")
|
||||
elif response.status_code == 403:
|
||||
raise ValueError("Account access denied")
|
||||
elif response.status_code == 429:
|
||||
raise ValueError("Too many login attempts")
|
||||
elif response.status_code != 200:
|
||||
raise ValueError(f"Login failed: {response.status_code}")
|
||||
|
||||
# Extract the _ncfa cookie
|
||||
ncfa_cookie = self._extract_ncfa_cookie(response)
|
||||
if not ncfa_cookie:
|
||||
raise ValueError("No session cookie received")
|
||||
|
||||
# Get user profile
|
||||
client.cookies.set("_ncfa", ncfa_cookie, domain="www.geoguessr.com")
|
||||
profile_response = await client.get(f"{base_url}/v3/profiles")
|
||||
|
||||
if profile_response.status_code != 200:
|
||||
raise ValueError("Failed to retrieve user profile")
|
||||
|
||||
profile = profile_response.json()
|
||||
|
||||
# Create and store session
|
||||
session = UserSession(
|
||||
ncfa_cookie=ncfa_cookie,
|
||||
user_id=profile.get("id", ""),
|
||||
username=profile.get("nick", ""),
|
||||
email=email,
|
||||
expires_at=datetime.now(UTC) + timedelta(days=30),
|
||||
)
|
||||
|
||||
session_token = await self._store_session(session)
|
||||
logger.info(f"User {session.username} logged in successfully")
|
||||
|
||||
return session_token, session
|
||||
|
||||
@staticmethod
|
||||
def _extract_ncfa_cookie(response: httpx.Response) -> Optional[str]:
|
||||
"""Extract _ncfa cookie from response."""
|
||||
# Try cookies jar first
|
||||
for cookie in response.cookies.jar:
|
||||
if cookie.name == "_ncfa":
|
||||
return cookie.value
|
||||
|
||||
# Try Set-Cookie header
|
||||
set_cookie = response.headers.get("set-cookie", "")
|
||||
if "_ncfa=" in set_cookie:
|
||||
for part in set_cookie.split(";"):
|
||||
if part.strip().startswith("_ncfa="):
|
||||
return part.strip()[6:]
|
||||
|
||||
return None
|
||||
|
||||
async def _store_session(self, session: UserSession) -> str:
|
||||
"""Store a session and return its token."""
|
||||
async with self._lock:
|
||||
session_token = self._generate_session_token()
|
||||
|
||||
# Remove old session for this user if exists
|
||||
if session.user_id in self._user_sessions:
|
||||
old_token = self._user_sessions[session.user_id]
|
||||
self._sessions.pop(old_token, None)
|
||||
|
||||
self._sessions[session_token] = session
|
||||
self._user_sessions[session.user_id] = session_token
|
||||
|
||||
return session_token
|
||||
|
||||
async def logout(self, session_token: str) -> bool:
|
||||
"""
|
||||
Logout and invalidate a session.
|
||||
|
||||
Args:
|
||||
session_token: Token of the session to logout
|
||||
|
||||
Returns:
|
||||
bool: True if session was found and removed, False otherwise
|
||||
"""
|
||||
async with self._lock:
|
||||
if session_token in self._sessions:
|
||||
session = self._sessions.pop(session_token)
|
||||
self._user_sessions.pop(session.user_id, None)
|
||||
logger.info(f"User {session.username} logged out")
|
||||
return True
|
||||
return False
|
||||
|
||||
async def get_session(self, session_token: Optional[str] = None) -> Optional[UserSession]:
|
||||
"""
|
||||
Get a session by token or return default if available.
|
||||
|
||||
Args:
|
||||
session_token: Optional session token to look up
|
||||
|
||||
Returns:
|
||||
UserSession if found and valid, None otherwise
|
||||
"""
|
||||
if session_token:
|
||||
async with self._lock:
|
||||
session = self._sessions.get(session_token)
|
||||
if session and session.is_valid():
|
||||
return session
|
||||
elif session:
|
||||
# Session expired, clean up
|
||||
self._sessions.pop(session_token, None)
|
||||
self._user_sessions.pop(session.user_id, None)
|
||||
|
||||
# Fall back to default cookie if available
|
||||
if self._default_cookie:
|
||||
return UserSession(
|
||||
ncfa_cookie=self._default_cookie,
|
||||
user_id="default",
|
||||
username="default",
|
||||
email="default",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
async def set_default_cookie(self, cookie: str) -> None:
|
||||
"""
|
||||
Set or update the default NCFA cookie.
|
||||
|
||||
Args:
|
||||
cookie: The NCFA cookie value to set as default
|
||||
"""
|
||||
async with self._lock:
|
||||
self._default_cookie = cookie
|
||||
logger.info("Default NCFA cookie updated")
|
||||
17
src/geoguessr_mcp/config.py
Normal file
17
src/geoguessr_mcp/config.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
"""Configuration management."""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Settings:
|
||||
HOST: str = os.getenv("MCP_HOST", "0.0.0.0")
|
||||
PORT: int = int(os.getenv("MCP_PORT", "8000"))
|
||||
TRANSPORT: str = os.getenv("MCP_TRANSPORT", "streamable-http")
|
||||
GEOGUESSR_BASE_URL: str = "https://www.geoguessr.com/api"
|
||||
GAME_SERVER_URL: str = "https://game-server.geoguessr.com/api"
|
||||
DEFAULT_NCFA_COOKIE: str | None = os.getenv("GEOGUESSR_NCFA_COOKIE")
|
||||
|
||||
|
||||
settings = Settings()
|
||||
19
src/geoguessr_mcp/main.py
Normal file
19
src/geoguessr_mcp/main.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
"""Main entry point for the Geoguessr MCP Server."""
|
||||
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
from .config import settings
|
||||
from .tools import register_all_tools
|
||||
|
||||
mcp = FastMCP(
|
||||
"Geoguessr Analyzer",
|
||||
instructions="MCP server for analyzing Geoguessr game statistics",
|
||||
host=settings.HOST,
|
||||
port=settings.PORT,
|
||||
)
|
||||
|
||||
# Register all tools
|
||||
register_all_tools(mcp)
|
||||
|
||||
if __name__ == "__main__":
|
||||
mcp.run(transport=settings.TRANSPORT)
|
||||
0
src/geoguessr_mcp/models/__init__.py
Normal file
0
src/geoguessr_mcp/models/__init__.py
Normal file
44
src/geoguessr_mcp/models/game.py
Normal file
44
src/geoguessr_mcp/models/game.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
"""Game-related data models."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
|
||||
@dataclass
|
||||
class RoundGuess:
|
||||
"""Represents a single round guess."""
|
||||
|
||||
score: int
|
||||
distance_meters: int
|
||||
time_seconds: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class Game:
|
||||
"""Represents a complete game."""
|
||||
|
||||
token: str
|
||||
map_name: str
|
||||
mode: str
|
||||
total_score: int
|
||||
rounds: List[RoundGuess]
|
||||
|
||||
@classmethod
|
||||
def from_api_response(cls, data: dict) -> "Game":
|
||||
"""Create Game from API response."""
|
||||
rounds = [
|
||||
RoundGuess(
|
||||
score=r.get("roundScoreInPoints", 0),
|
||||
distance_meters=r.get("distanceInMeters", 0),
|
||||
time_seconds=r.get("time", 0),
|
||||
)
|
||||
for r in data.get("player", {}).get("guesses", [])
|
||||
]
|
||||
|
||||
return cls(
|
||||
token=data["token"],
|
||||
map_name=data.get("map", {}).get("name", "Unknown"),
|
||||
mode=data.get("type", "Unknown"),
|
||||
total_score=sum(r.score for r in rounds),
|
||||
rounds=rounds,
|
||||
)
|
||||
29
src/geoguessr_mcp/models/profile.py
Normal file
29
src/geoguessr_mcp/models/profile.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
"""Profile-related data models."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserProfile:
|
||||
"""User profile information."""
|
||||
|
||||
id: str
|
||||
nick: str
|
||||
email: str
|
||||
country: str
|
||||
level: int
|
||||
created: str
|
||||
is_verified: bool
|
||||
|
||||
@classmethod
|
||||
def from_api_response(cls, data: dict) -> "UserProfile":
|
||||
"""Create UserProfile from API response."""
|
||||
return cls(
|
||||
id=data["id"],
|
||||
nick=data["nick"],
|
||||
email=data.get("email", ""),
|
||||
country=data.get("country", ""),
|
||||
level=data.get("level", 0),
|
||||
created=data.get("created", ""),
|
||||
is_verified=data.get("isVerified", False),
|
||||
)
|
||||
0
src/geoguessr_mcp/models/stats.py
Normal file
0
src/geoguessr_mcp/models/stats.py
Normal file
0
src/geoguessr_mcp/services/__init__.py
Normal file
0
src/geoguessr_mcp/services/__init__.py
Normal file
30
src/geoguessr_mcp/services/analysis_service.py
Normal file
30
src/geoguessr_mcp/services/analysis_service.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
"""Analysis and statistics calculations."""
|
||||
|
||||
from typing import List
|
||||
|
||||
from ..models.game import Game
|
||||
|
||||
|
||||
class AnalysisService:
|
||||
"""Service for analyzing game data."""
|
||||
|
||||
@staticmethod
|
||||
def calculate_statistics(games: List[Game]) -> dict:
|
||||
"""Calculate aggregate statistics from games."""
|
||||
if not games:
|
||||
return {"games_analyzed": 0, "total_score": 0, "average_score": 0, "perfect_rounds": 0}
|
||||
|
||||
total_score = sum(g.total_score for g in games)
|
||||
total_rounds = sum(len(g.rounds) for g in games)
|
||||
perfect_rounds = sum(1 for g in games for r in g.rounds if r.score == 5000)
|
||||
|
||||
return {
|
||||
"games_analyzed": len(games),
|
||||
"total_score": total_score,
|
||||
"average_score": total_score / len(games),
|
||||
"total_rounds": total_rounds,
|
||||
"perfect_rounds": perfect_rounds,
|
||||
"perfect_round_percentage": (
|
||||
(perfect_rounds / total_rounds * 100) if total_rounds > 0 else 0
|
||||
),
|
||||
}
|
||||
0
src/geoguessr_mcp/services/competitive_service.py
Normal file
0
src/geoguessr_mcp/services/competitive_service.py
Normal file
0
src/geoguessr_mcp/services/game_service.py
Normal file
0
src/geoguessr_mcp/services/game_service.py
Normal file
136
src/geoguessr_mcp/services/profile_service.py
Normal file
136
src/geoguessr_mcp/services/profile_service.py
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
"""
|
||||
Profile-related business logic.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from ..api.client import GeoguessrClient
|
||||
from ..api.endpoints import Endpoints
|
||||
from ..models.profile import UserProfile, UserStats
|
||||
|
||||
|
||||
class ProfileService:
|
||||
"""Service for profile operations."""
|
||||
|
||||
def __init__(self, client: GeoguessrClient):
|
||||
"""
|
||||
Initialize the profile service.
|
||||
|
||||
Args:
|
||||
client: GeoGuessr API client
|
||||
"""
|
||||
self.client = client
|
||||
|
||||
async def get_profile(
|
||||
self,
|
||||
session_token: Optional[str] = None
|
||||
) -> UserProfile:
|
||||
"""
|
||||
Get user profile.
|
||||
|
||||
Args:
|
||||
session_token: Optional session token for authentication
|
||||
|
||||
Returns:
|
||||
UserProfile with user information
|
||||
|
||||
Raises:
|
||||
httpx.HTTPError: If the API request fails
|
||||
"""
|
||||
response = await self.client.get(
|
||||
Endpoints.PROFILES.GET_PROFILE,
|
||||
session_token
|
||||
)
|
||||
data = response.json()
|
||||
return UserProfile.from_api_response(data)
|
||||
|
||||
async def get_stats(
|
||||
self,
|
||||
session_token: Optional[str] = None
|
||||
) -> UserStats:
|
||||
"""
|
||||
Get user statistics.
|
||||
|
||||
Args:
|
||||
session_token: Optional session token for authentication
|
||||
|
||||
Returns:
|
||||
UserStats with user statistics
|
||||
|
||||
Raises:
|
||||
httpx.HTTPError: If the API request fails
|
||||
"""
|
||||
response = await self.client.get(
|
||||
Endpoints.PROFILES.GET_STATS,
|
||||
session_token
|
||||
)
|
||||
data = response.json()
|
||||
return UserStats.from_api_response(data)
|
||||
|
||||
async def get_extended_stats(
|
||||
self,
|
||||
session_token: Optional[str] = None
|
||||
) -> dict:
|
||||
"""
|
||||
Get extended user statistics.
|
||||
|
||||
Args:
|
||||
session_token: Optional session token for authentication
|
||||
|
||||
Returns:
|
||||
Dictionary with extended statistics
|
||||
|
||||
Raises:
|
||||
httpx.HTTPError: If the API request fails
|
||||
"""
|
||||
response = await self.client.get(
|
||||
Endpoints.PROFILES.GET_EXTENDED_STATS,
|
||||
session_token
|
||||
)
|
||||
return response.json()
|
||||
|
||||
async def get_achievements(
|
||||
self,
|
||||
session_token: Optional[str] = None
|
||||
) -> list:
|
||||
"""
|
||||
Get user achievements.
|
||||
|
||||
Args:
|
||||
session_token: Optional session token for authentication
|
||||
|
||||
Returns:
|
||||
List of achievement dictionaries
|
||||
|
||||
Raises:
|
||||
httpx.HTTPError: If the API request fails
|
||||
"""
|
||||
response = await self.client.get(
|
||||
Endpoints.PROFILES.GET_ACHIEVEMENTS,
|
||||
session_token
|
||||
)
|
||||
return response.json()
|
||||
|
||||
async def get_public_profile(
|
||||
self,
|
||||
user_id: str,
|
||||
session_token: Optional[str] = None
|
||||
) -> UserProfile:
|
||||
"""
|
||||
Get public profile of another user.
|
||||
|
||||
Args:
|
||||
user_id: User ID to fetch
|
||||
session_token: Optional session token for authentication
|
||||
|
||||
Returns:
|
||||
UserProfile with public user information
|
||||
|
||||
Raises:
|
||||
httpx.HTTPError: If the API request fails
|
||||
"""
|
||||
response = await self.client.get(
|
||||
Endpoints.PROFILES.get_public_profile(user_id),
|
||||
session_token
|
||||
)
|
||||
data = response.json()
|
||||
return UserProfile.from_api_response(data)
|
||||
31
src/geoguessr_mcp/tools/__init__.py
Normal file
31
src/geoguessr_mcp/tools/__init__.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
"""Register all MCP tools."""
|
||||
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
from ..api.client import GeoguessrClient
|
||||
from ..auth.session import SessionManager
|
||||
from ..services.analysis_service import AnalysisService
|
||||
from ..services.game_service import GameService
|
||||
from ..services.profile_service import ProfileService
|
||||
from .analysis_tools import register_analysis_tools
|
||||
from .auth_tools import register_auth_tools
|
||||
from .game_tools import register_game_tools
|
||||
from .profile_tools import register_profile_tools
|
||||
|
||||
|
||||
def register_all_tools(mcp: FastMCP):
|
||||
"""Register all tools with the MCP server."""
|
||||
# Initialize dependencies
|
||||
session_manager = SessionManager()
|
||||
client = GeoguessrClient(session_manager)
|
||||
|
||||
# Initialize services
|
||||
profile_service = ProfileService(client)
|
||||
game_service = GameService(client)
|
||||
analysis_service = AnalysisService()
|
||||
|
||||
# Register tools
|
||||
register_auth_tools(mcp, session_manager)
|
||||
register_profile_tools(mcp, profile_service)
|
||||
register_game_tools(mcp, game_service)
|
||||
register_analysis_tools(mcp, analysis_service, game_service)
|
||||
124
src/geoguessr_mcp/tools/analysis_tools.py
Normal file
124
src/geoguessr_mcp/tools/analysis_tools.py
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
|
||||
@mcp.tool()
|
||||
async def analyze_recent_games(count: int = 10) -> dict:
|
||||
"""
|
||||
Analyze recent games and provide statistics summary.
|
||||
Fetches recent games from the activity feed and calculates aggregate statistics.
|
||||
|
||||
Args:
|
||||
count: Number of recent games to analyze (default: 10)
|
||||
"""
|
||||
async with await get_async_session() as client:
|
||||
# Get activity feed
|
||||
feed_response = await client.get(
|
||||
f"{GEOGUESSR_BASE_URL}/v4/feed/private",
|
||||
params={"count": count * 2, "page": 0}
|
||||
)
|
||||
feed_response.raise_for_status()
|
||||
feed = feed_response.json()
|
||||
|
||||
games_analyzed = []
|
||||
total_score = 0
|
||||
total_rounds = 0
|
||||
perfect_rounds = 0
|
||||
|
||||
for entry in feed.get("entries", []):
|
||||
if entry.get("type") == "PlayedGame" and len(games_analyzed) < count:
|
||||
game_token = entry.get("payload", {}).get("gameToken")
|
||||
if game_token:
|
||||
try:
|
||||
game_response = await client.get(f"{GEOGUESSR_BASE_URL}/v3/games/{game_token}")
|
||||
if game_response.status_code == 200:
|
||||
game = game_response.json()
|
||||
|
||||
game_info = {
|
||||
"token": game_token,
|
||||
"map": game.get("map", {}).get("name", "Unknown"),
|
||||
"mode": game.get("type", "Unknown"),
|
||||
"total_score": 0,
|
||||
"rounds": []
|
||||
}
|
||||
|
||||
for round_data in game.get("player", {}).get("guesses", []):
|
||||
round_score = round_data.get("roundScoreInPoints", 0)
|
||||
game_info["total_score"] += round_score
|
||||
game_info["rounds"].append({
|
||||
"score": round_score,
|
||||
"distance": round_data.get("distanceInMeters", 0),
|
||||
"time": round_data.get("time", 0)
|
||||
})
|
||||
|
||||
total_rounds += 1
|
||||
if round_score == 5000:
|
||||
perfect_rounds += 1
|
||||
|
||||
total_score += game_info["total_score"]
|
||||
games_analyzed.append(game_info)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch game {game_token}: {e}")
|
||||
|
||||
return {
|
||||
"games_analyzed": len(games_analyzed),
|
||||
"total_score": total_score,
|
||||
"average_score": total_score / len(games_analyzed) if games_analyzed else 0,
|
||||
"total_rounds": total_rounds,
|
||||
"perfect_rounds": perfect_rounds,
|
||||
"perfect_round_percentage": (perfect_rounds / total_rounds * 100) if total_rounds > 0 else 0,
|
||||
"games": games_analyzed
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_performance_summary() -> dict:
|
||||
"""
|
||||
Get a comprehensive performance summary combining profile stats,
|
||||
achievements, and season information.
|
||||
"""
|
||||
async with await get_async_session() as client:
|
||||
results = {}
|
||||
|
||||
# Get profile
|
||||
try:
|
||||
profile_response = await client.get(f"{GEOGUESSR_BASE_URL}/v3/profiles")
|
||||
profile_response.raise_for_status()
|
||||
results["profile"] = profile_response.json()
|
||||
except Exception as e:
|
||||
results["profile_error"] = str(e)
|
||||
|
||||
# Get stats
|
||||
try:
|
||||
stats_response = await client.get(f"{GEOGUESSR_BASE_URL}/v3/profiles/stats")
|
||||
stats_response.raise_for_status()
|
||||
results["stats"] = stats_response.json()
|
||||
except Exception as e:
|
||||
results["stats_error"] = str(e)
|
||||
|
||||
# Get extended stats
|
||||
try:
|
||||
extended_response = await client.get(f"{GEOGUESSR_BASE_URL}/v4/stats/me")
|
||||
extended_response.raise_for_status()
|
||||
results["extended_stats"] = extended_response.json()
|
||||
except Exception as e:
|
||||
results["extended_stats_error"] = str(e)
|
||||
|
||||
# Get season stats
|
||||
try:
|
||||
season_response = await client.get(f"{GEOGUESSR_BASE_URL}/v4/seasons/active/stats")
|
||||
season_response.raise_for_status()
|
||||
results["current_season"] = season_response.json()
|
||||
except Exception as e:
|
||||
results["season_error"] = str(e)
|
||||
|
||||
# Get achievements
|
||||
try:
|
||||
achievements_response = await client.get(f"{GEOGUESSR_BASE_URL}/v3/profiles/achievements")
|
||||
achievements_response.raise_for_status()
|
||||
achievements = achievements_response.json()
|
||||
results["achievements_summary"] = {
|
||||
"total": len(achievements) if isinstance(achievements, list) else 0,
|
||||
"achievements": achievements
|
||||
}
|
||||
except Exception as e:
|
||||
results["achievements_error"] = str(e)
|
||||
|
||||
return results
|
||||
182
src/geoguessr_mcp/tools/auth_tools.py
Normal file
182
src/geoguessr_mcp/tools/auth_tools.py
Normal file
|
|
@ -0,0 +1,182 @@
|
|||
"""MCP tools for auth operations."""
|
||||
import logging
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
from ..auth.session import SessionManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def register_auth_tools(mcp: FastMCP, session_manager: SessionManager):
|
||||
"""Register auth-related tools."""
|
||||
|
||||
@mcp.tool()
|
||||
async def login(email: str, password: str) -> dict:
|
||||
"""
|
||||
Authenticate with Geoguessr using your email and password.
|
||||
This creates a session that will be used for all later API calls.
|
||||
|
||||
Args:
|
||||
email: Your Geoguessr account email
|
||||
password: Your Geoguessr account password
|
||||
|
||||
Returns:
|
||||
Session information including username and session token
|
||||
|
||||
Note: Your credentials are only used to get an authentication token
|
||||
from Geoguessr. They are not stored on the server.
|
||||
"""
|
||||
|
||||
try:
|
||||
session_token, session = await session_manager.login(email, password)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Successfully logged in as {session.username}",
|
||||
"username": session.username,
|
||||
"user_id": session.user_id,
|
||||
"session_token": session_token,
|
||||
"expires_at": session.expires_at.isoformat() if session.expires_at else None,
|
||||
}
|
||||
except ValueError as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
except Exception as e:
|
||||
logger.error(f"Login error: {e}")
|
||||
return {"success": False, "error": f"An unexpected error occurred: {str(e)}"}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def logout() -> dict:
|
||||
"""
|
||||
Logout from the current Geoguessr session.
|
||||
This invalidates the current session token.
|
||||
"""
|
||||
global _current_session_token
|
||||
|
||||
if _current_session_token:
|
||||
success = await session_manager.logout(_current_session_token)
|
||||
_current_session_token = None
|
||||
return {
|
||||
"success": success,
|
||||
"message": "Successfully logged out" if success else "No active session to logout",
|
||||
}
|
||||
|
||||
return {"success": False, "message": "No active session"}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def set_session_token(token: str) -> dict:
|
||||
"""
|
||||
Set an existing session token for authentication.
|
||||
Use this if you have a previously obtained session token.
|
||||
|
||||
Args:
|
||||
token: A valid session token from a previous login
|
||||
"""
|
||||
global _current_session_token
|
||||
|
||||
session = await session_manager.get_session(token)
|
||||
if session and session.is_valid():
|
||||
_current_session_token = token
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Session set for user {session.username}",
|
||||
"username": session.username,
|
||||
}
|
||||
|
||||
return {"success": False, "error": "Invalid or expired session token"}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def set_ncfa_cookie(cookie: str) -> dict:
|
||||
"""
|
||||
Directly set the _ncfa cookie for authentication.
|
||||
Use this if you've manually extracted the cookie from your browser.
|
||||
|
||||
Args:
|
||||
cookie: The _ncfa cookie value from your browser
|
||||
|
||||
Note: This sets the cookie as the default for all requests.
|
||||
"""
|
||||
global _current_session_token
|
||||
|
||||
# Validate the cookie by making a test request
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
client.cookies.set("_ncfa", cookie, domain="www.geoguessr.com")
|
||||
response = await client.get(f"{GEOGUESSR_BASE_URL}/v3/profiles")
|
||||
|
||||
if response.status_code != 200:
|
||||
return {"success": False, "error": "Invalid cookie - authentication failed"}
|
||||
|
||||
profile = response.json()
|
||||
|
||||
# Create a session from the cookie
|
||||
session = UserSession(
|
||||
ncfa_cookie=cookie,
|
||||
user_id=profile.get("id", ""),
|
||||
username=profile.get("nick", ""),
|
||||
email="manual@cookie",
|
||||
expires_at=datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=30),
|
||||
)
|
||||
|
||||
# Store as a session
|
||||
session_token = secrets.token_urlsafe(32)
|
||||
async with session_manager._lock:
|
||||
session_manager._sessions[session_token] = session
|
||||
session_manager._user_sessions[session.user_id] = session_token
|
||||
|
||||
_current_session_token = session_token
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Cookie set successfully. Authenticated as {session.username}",
|
||||
"username": session.username,
|
||||
"user_id": session.user_id,
|
||||
"session_token": session_token,
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_auth_status() -> dict:
|
||||
"""
|
||||
Check the current authentication status.
|
||||
Returns information about the current session or authentication method.
|
||||
"""
|
||||
global _current_session_token
|
||||
|
||||
# Check for active session
|
||||
if _current_session_token:
|
||||
session = await session_manager.get_session(_current_session_token)
|
||||
if session and session.is_valid():
|
||||
return {
|
||||
"authenticated": True,
|
||||
"method": "session",
|
||||
"username": session.username,
|
||||
"user_id": session.user_id,
|
||||
"expires_at": session.expires_at.isoformat() if session.expires_at else None,
|
||||
}
|
||||
|
||||
# Check for environment variable
|
||||
env_cookie = os.environ.get("GEOGUESSR_NCFA_COOKIE")
|
||||
if env_cookie:
|
||||
# Validate the environment cookie
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
client.cookies.set("_ncfa", env_cookie, domain="www.geoguessr.com")
|
||||
response = await client.get(f"{GEOGUESSR_BASE_URL}/v3/profiles")
|
||||
|
||||
if response.status_code == 200:
|
||||
profile = response.json()
|
||||
return {
|
||||
"authenticated": True,
|
||||
"method": "environment_variable",
|
||||
"username": profile.get("nick", "Unknown"),
|
||||
"user_id": profile.get("id", "Unknown"),
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {
|
||||
"authenticated": False,
|
||||
"message": "Not authenticated. Use 'login' with your GeoGuessr credentials or 'set_ncfa_cookie' with a valid cookie.",
|
||||
}
|
||||
0
src/geoguessr_mcp/tools/competitive_tools.py
Normal file
0
src/geoguessr_mcp/tools/competitive_tools.py
Normal file
0
src/geoguessr_mcp/tools/game_tools.py
Normal file
0
src/geoguessr_mcp/tools/game_tools.py
Normal file
26
src/geoguessr_mcp/tools/profile_tools.py
Normal file
26
src/geoguessr_mcp/tools/profile_tools.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
"""MCP tools for profile operations."""
|
||||
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
from ..services.profile_service import ProfileService
|
||||
|
||||
|
||||
def register_profile_tools(mcp: FastMCP, profile_service: ProfileService):
|
||||
"""Register profile-related tools."""
|
||||
|
||||
@mcp.tool()
|
||||
async def get_my_profile(session_token: str = "") -> dict:
|
||||
"""Get the current user's profile information."""
|
||||
profile = await profile_service.get_profile(session_token if session_token else None)
|
||||
return {
|
||||
"id": profile.id,
|
||||
"nick": profile.nick,
|
||||
"email": profile.email,
|
||||
"country": profile.country,
|
||||
"level": profile.level,
|
||||
}
|
||||
|
||||
@mcp.tool()
|
||||
async def get_my_stats(session_token: str = "") -> dict:
|
||||
"""Get the current user's statistics."""
|
||||
return await profile_service.get_stats(session_token if session_token else None)
|
||||
1089
src/server.py
1089
src/server.py
File diff suppressed because it is too large
Load diff
54
src/tests/conftest.py
Normal file
54
src/tests/conftest.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
"""Shared test fixtures."""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_env(monkeypatch):
|
||||
"""Set up environment variables for testing."""
|
||||
monkeypatch.setenv("GEOGUESSR_NCFA_COOKIE", "test_cookie_value")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
"""Create a mock async HTTP session."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.__aexit__.return_value = None
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_profile_data():
|
||||
"""Standard profile response data."""
|
||||
return {
|
||||
"id": "test-user-id",
|
||||
"nick": "TestPlayer",
|
||||
"email": "test@example.com",
|
||||
"country": "FR",
|
||||
"created": "2025-01-01T00:00:00.000Z",
|
||||
"isVerified": True,
|
||||
"level": 50,
|
||||
"rating": {
|
||||
"rating": 1500,
|
||||
"deviation": 100
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_game_data():
|
||||
"""Standard game response data."""
|
||||
return {
|
||||
"token": "ABC123",
|
||||
"type": "standard",
|
||||
"map": {"name": "World"},
|
||||
"player": {
|
||||
"guesses": [
|
||||
{"roundScoreInPoints": 5000, "distanceInMeters": 0, "time": 10},
|
||||
{"roundScoreInPoints": 4500, "distanceInMeters": 100, "time": 15},
|
||||
]
|
||||
},
|
||||
}
|
||||
0
src/tests/e2e/__init__.py
Normal file
0
src/tests/e2e/__init__.py
Normal file
16
src/tests/e2e/test_full_workflow.py
Normal file
16
src/tests/e2e/test_full_workflow.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
"""Integration tests for authentication."""
|
||||
|
||||
import pytest
|
||||
|
||||
from geoguessr_mcp.auth.session import SessionManager
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestAuthFlow:
|
||||
"""Integration tests for authentication flow."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_login_logout_cycle(self):
|
||||
"""Test complete login and logout cycle."""
|
||||
# This would use real API calls in a test environment
|
||||
pass
|
||||
0
src/tests/integration/__init__.py
Normal file
0
src/tests/integration/__init__.py
Normal file
0
src/tests/integration/test_api_client.py
Normal file
0
src/tests/integration/test_api_client.py
Normal file
0
src/tests/integration/test_auth_flow.py
Normal file
0
src/tests/integration/test_auth_flow.py
Normal file
|
|
@ -1,201 +0,0 @@
|
|||
"""
|
||||
Tests for GeoGuessr MCP Server
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
import httpx
|
||||
|
||||
|
||||
# Mock the environment variable before importing server
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_env(monkeypatch):
|
||||
"""Set up environment variables for testing."""
|
||||
monkeypatch.setenv("GEOGUESSR_NCFA_COOKIE", "test_cookie_value")
|
||||
|
||||
|
||||
class TestProfileTools:
|
||||
"""Tests for profile-related tools."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_my_profile_success(self):
|
||||
"""Test successful profile retrieval."""
|
||||
from server import get_my_profile
|
||||
|
||||
mock_response = {
|
||||
"id": "test-user-id",
|
||||
"nick": "TestPlayer",
|
||||
"country": "US",
|
||||
"level": 50,
|
||||
}
|
||||
|
||||
with patch("server.get_async_session") as mock_session:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.__aexit__.return_value = None
|
||||
|
||||
mock_http_response = MagicMock()
|
||||
mock_http_response.json.return_value = mock_response
|
||||
mock_http_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client.get = AsyncMock(return_value=mock_http_response)
|
||||
mock_session.return_value = mock_client
|
||||
|
||||
result = await get_my_profile()
|
||||
|
||||
assert result["nick"] == "TestPlayer"
|
||||
assert result["id"] == "test-user-id"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_my_stats_success(self):
|
||||
"""Test successful stats retrieval."""
|
||||
from server import get_my_stats
|
||||
|
||||
mock_response = {
|
||||
"gamesPlayed": 100,
|
||||
"averageScore": 4500,
|
||||
"highScore": 5000,
|
||||
}
|
||||
|
||||
with patch("server.get_async_session") as mock_session:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.__aexit__.return_value = None
|
||||
|
||||
mock_http_response = MagicMock()
|
||||
mock_http_response.json.return_value = mock_response
|
||||
mock_http_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client.get = AsyncMock(return_value=mock_http_response)
|
||||
mock_session.return_value = mock_client
|
||||
|
||||
result = await get_my_stats()
|
||||
|
||||
assert result["gamesPlayed"] == 100
|
||||
assert result["averageScore"] == 4500
|
||||
|
||||
|
||||
class TestGameTools:
|
||||
"""Tests for game-related tools."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_game_details_success(self):
|
||||
"""Test successful game details retrieval."""
|
||||
from server import get_game_details
|
||||
|
||||
mock_response = {
|
||||
"token": "ABC123",
|
||||
"type": "standard",
|
||||
"map": {"name": "World"},
|
||||
"player": {
|
||||
"guesses": [
|
||||
{"roundScoreInPoints": 5000, "distanceInMeters": 0},
|
||||
{"roundScoreInPoints": 4500, "distanceInMeters": 100},
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
with patch("server.get_async_session") as mock_session:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.__aexit__.return_value = None
|
||||
|
||||
mock_http_response = MagicMock()
|
||||
mock_http_response.json.return_value = mock_response
|
||||
mock_http_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client.get = AsyncMock(return_value=mock_http_response)
|
||||
mock_session.return_value = mock_client
|
||||
|
||||
result = await get_game_details("ABC123")
|
||||
|
||||
assert result["token"] == "ABC123"
|
||||
assert result["map"]["name"] == "World"
|
||||
assert len(result["player"]["guesses"]) == 2
|
||||
|
||||
|
||||
class TestAnalysisTools:
|
||||
"""Tests for analysis tools."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_recent_games_empty(self):
|
||||
"""Test analysis with no games in feed."""
|
||||
from server import analyze_recent_games
|
||||
|
||||
mock_feed_response = {"entries": []}
|
||||
|
||||
with patch("server.get_async_session") as mock_session:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.__aexit__.return_value = None
|
||||
|
||||
mock_http_response = MagicMock()
|
||||
mock_http_response.json.return_value = mock_feed_response
|
||||
mock_http_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client.get = AsyncMock(return_value=mock_http_response)
|
||||
mock_session.return_value = mock_client
|
||||
|
||||
result = await analyze_recent_games(count=5)
|
||||
|
||||
assert result["games_analyzed"] == 0
|
||||
assert result["total_score"] == 0
|
||||
assert result["games"] == []
|
||||
|
||||
|
||||
class TestAuthentication:
|
||||
"""Tests for authentication handling."""
|
||||
|
||||
def test_get_ncfa_cookie_missing(self, monkeypatch):
|
||||
"""Test error when cookie is not set."""
|
||||
monkeypatch.delenv("GEOGUESSR_NCFA_COOKIE", raising=False)
|
||||
|
||||
from server import get_ncfa_cookie
|
||||
|
||||
with pytest.raises(ValueError, match="GEOGUESSR_NCFA_COOKIE"):
|
||||
get_ncfa_cookie()
|
||||
|
||||
def test_get_ncfa_cookie_present(self, monkeypatch):
|
||||
"""Test cookie retrieval when set."""
|
||||
monkeypatch.setenv("GEOGUESSR_NCFA_COOKIE", "my_test_cookie")
|
||||
|
||||
from server import get_ncfa_cookie
|
||||
|
||||
cookie = get_ncfa_cookie()
|
||||
assert cookie == "my_test_cookie"
|
||||
|
||||
|
||||
# Integration tests (marked to skip by default)
|
||||
@pytest.mark.integration
|
||||
class TestIntegration:
|
||||
"""Integration tests that require a real GeoGuessr cookie."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_profile_fetch(self):
|
||||
"""Test fetching real profile data."""
|
||||
import os
|
||||
if not os.environ.get("GEOGUESSR_NCFA_COOKIE") or \
|
||||
os.environ.get("GEOGUESSR_NCFA_COOKIE") == "test_cookie_value":
|
||||
pytest.skip("Real NCFA cookie not configured")
|
||||
|
||||
from server import get_my_profile
|
||||
|
||||
result = await get_my_profile()
|
||||
assert "nick" in result
|
||||
assert "id" in result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""Run tests automatically when script is executed directly."""
|
||||
import sys
|
||||
|
||||
# Run pytest with verbose output and show print statements
|
||||
exit_code = pytest.main([
|
||||
__file__,
|
||||
"-v", # Verbose output
|
||||
"-s", # Show print statements
|
||||
"--tb=short", # Shorter traceback format
|
||||
"-m", "not integration", # Skip integration tests by default
|
||||
])
|
||||
|
||||
sys.exit(exit_code)
|
||||
0
src/tests/unit/__init__.py
Normal file
0
src/tests/unit/__init__.py
Normal file
0
src/tests/unit/test_analysis_service.py
Normal file
0
src/tests/unit/test_analysis_service.py
Normal file
0
src/tests/unit/test_game_service.py
Normal file
0
src/tests/unit/test_game_service.py
Normal file
81
src/tests/unit/test_profile_service.py
Normal file
81
src/tests/unit/test_profile_service.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
"""Unit tests for ProfileService."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from geoguessr_mcp.models.profile import UserProfile
|
||||
from geoguessr_mcp.services.profile_service import ProfileService
|
||||
from geoguessr_mcp.config import settings
|
||||
|
||||
|
||||
class TestProfileService:
|
||||
"""Tests for ProfileService."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_profile_success(self, mock_session, mock_profile_data):
|
||||
"""Test successful profile retrieval."""
|
||||
# Create mock client
|
||||
mock_client = MagicMock()
|
||||
mock_client.base_url = settings.GEOGUESSR_BASE_URL
|
||||
mock_client.get_async_session = AsyncMock(return_value=mock_session)
|
||||
|
||||
# Mock HTTP response
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_profile_data
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_session.get = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Test
|
||||
service = ProfileService(mock_client)
|
||||
profile = await service.get_profile()
|
||||
|
||||
assert isinstance(profile, UserProfile)
|
||||
assert profile.nick == "TestPlayer"
|
||||
assert profile.id == "test-user-id"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_my_stats_success(self, mock_session, mock_profile_data):
|
||||
"""Test successful stats retrieval."""
|
||||
# Create mock client
|
||||
mock_client = MagicMock()
|
||||
mock_client.base_url = settings.GEOGUESSR_BASE_URL
|
||||
mock_client.get_async_session = AsyncMock(return_value=mock_session)
|
||||
|
||||
# Mock HTTP response
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_profile_data
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_session.get = AsyncMock(return_value=mock_response)
|
||||
|
||||
service = ProfileService(mock_client)
|
||||
profile = await service.get_stats()
|
||||
|
||||
assert isinstance(profile, UserProfile)
|
||||
assert profile. == 100
|
||||
assert result["averageScore"] == 4500
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_extended_stats(self, mock_session):
|
||||
"""Test extended stats retrieval."""
|
||||
from server import get_extended_stats
|
||||
|
||||
extended_stats = {
|
||||
"totalGames": 150,
|
||||
"winRate": 0.65,
|
||||
"averageTime": 180
|
||||
}
|
||||
|
||||
with patch("server.get_async_session") as mock_get_session:
|
||||
mock_http_response = MagicMock()
|
||||
mock_http_response.json.return_value = extended_stats
|
||||
mock_http_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_session.get = AsyncMock(return_value=mock_http_response)
|
||||
mock_get_session.return_value = mock_session
|
||||
|
||||
result = await get_extended_stats()
|
||||
|
||||
assert result["totalGames"] == 150
|
||||
assert result["winRate"] == 0.65
|
||||
169
src/tests/unit/test_session.py
Normal file
169
src/tests/unit/test_session.py
Normal file
|
|
@ -0,0 +1,169 @@
|
|||
"""Unit tests for session management."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from geoguessr_mcp.auth.session import SessionManager, UserSession
|
||||
|
||||
# ============================================================================
|
||||
# USER SESSION TESTS
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestUserSession:
|
||||
"""Tests for UserSession dataclass."""
|
||||
|
||||
def test_valid_session(self):
|
||||
"""Test that a valid session is recognized as valid."""
|
||||
session = UserSession(
|
||||
ncfa_cookie="test_cookie",
|
||||
user_id="user123",
|
||||
username="TestUser",
|
||||
email="test@example.com",
|
||||
expires_at=datetime.now(UTC) + timedelta(days=1),
|
||||
)
|
||||
assert session.is_valid()
|
||||
|
||||
def test_expired_session(self):
|
||||
"""Test that an expired session is invalid."""
|
||||
session = UserSession(
|
||||
ncfa_cookie="test_cookie",
|
||||
user_id="user123",
|
||||
username="TestUser",
|
||||
email="test@example.com",
|
||||
expires_at=datetime.now(UTC) - timedelta(days=1),
|
||||
)
|
||||
assert not session.is_valid()
|
||||
|
||||
def test_session_without_cookie(self):
|
||||
"""Test that a session without cookie is invalid."""
|
||||
session = UserSession(
|
||||
ncfa_cookie="", user_id="user123", username="TestUser", email="test@example.com"
|
||||
)
|
||||
assert not session.is_valid()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SESSION MANAGER TESTS
|
||||
# ============================================================================
|
||||
|
||||
class TestSessionManager:
|
||||
"""Tests for SessionManager."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_success(self, mock_profile_response):
|
||||
"""Test successful login flow."""
|
||||
|
||||
manager = SessionManager()
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client_class:
|
||||
# Create mock client
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.__aexit__.return_value = None
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Mock login response
|
||||
login_response = MagicMock()
|
||||
login_response.status_code = 200
|
||||
login_response.cookies.jar = []
|
||||
|
||||
# Create mock cookie
|
||||
mock_cookie = MagicMock()
|
||||
mock_cookie.name = "_ncfa"
|
||||
mock_cookie.value = "test_ncfa_cookie_value"
|
||||
login_response.cookies.jar.append(mock_cookie)
|
||||
|
||||
# Mock profile response
|
||||
profile_response = MagicMock()
|
||||
profile_response.status_code = 200
|
||||
profile_response.json.return_value = mock_profile_response
|
||||
|
||||
# Set up mock client responses
|
||||
mock_client.post = AsyncMock(return_value=login_response)
|
||||
mock_client.get = AsyncMock(return_value=profile_response)
|
||||
mock_client.cookies.set = MagicMock()
|
||||
|
||||
# Perform login
|
||||
session_token, session = await manager.login("test@example.com", "password123")
|
||||
|
||||
# Assertions
|
||||
assert session_token is not None
|
||||
assert len(session_token) > 0
|
||||
assert session.ncfa_cookie == "test_ncfa_cookie_value"
|
||||
assert session.user_id == "test-user-id"
|
||||
assert session.username == "TestPlayer"
|
||||
assert session.email == "test@example.com"
|
||||
assert session.is_valid()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_invalid_credentials(self):
|
||||
"""Test login with invalid credentials."""
|
||||
manager = SessionManager()
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.__aexit__.return_value = None
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Mock 401 response
|
||||
login_response = MagicMock()
|
||||
login_response.status_code = 401
|
||||
mock_client.post = AsyncMock(return_value=login_response)
|
||||
|
||||
# Attempt login and expect error
|
||||
with pytest.raises(ValueError, match="Invalid email or password"):
|
||||
await manager.login("wrong@example.com", "wrong_pass")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logout(self, mock_profile_response):
|
||||
"""Test logout functionality."""
|
||||
|
||||
manager = SessionManager()
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client_class:
|
||||
# Set up successful login first
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.__aexit__.return_value = None
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
login_response = MagicMock()
|
||||
login_response.status_code = 200
|
||||
login_response.cookies.jar = []
|
||||
mock_cookie = MagicMock()
|
||||
mock_cookie.name = "_ncfa"
|
||||
mock_cookie.value = "test_cookie"
|
||||
login_response.cookies.jar.append(mock_cookie)
|
||||
|
||||
profile_response = MagicMock()
|
||||
profile_response.status_code = 200
|
||||
profile_response.json.return_value = mock_profile_response
|
||||
|
||||
mock_client.post = AsyncMock(return_value=login_response)
|
||||
mock_client.get = AsyncMock(return_value=profile_response)
|
||||
mock_client.cookies.set = MagicMock()
|
||||
|
||||
session_token, _ = await manager.login("test@example.com", "password")
|
||||
|
||||
# Now logout
|
||||
result = await manager.logout(session_token)
|
||||
assert result is True
|
||||
|
||||
# Verify session is removed
|
||||
session = await manager.get_session(session_token)
|
||||
assert session is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session_with_default_cookie(self):
|
||||
"""Test getting session with default cookie from environment."""
|
||||
|
||||
manager = SessionManager()
|
||||
|
||||
# Should use default cookie from environment
|
||||
session = await manager.get_session()
|
||||
assert session is not None
|
||||
assert session.ncfa_cookie == "test_cookie_value"
|
||||
Loading…
Add table
Add a link
Reference in a new issue