/security-auditor skills changes done

This commit is contained in:
2026-04-24 12:58:46 +05:30
parent 7bee838bb0
commit 373adc776f
6 changed files with 241 additions and 293 deletions

View File

@@ -23,7 +23,8 @@
"Bash(curl:*)", "Bash(curl:*)",
"Bash(npx lighthouse:*)", "Bash(npx lighthouse:*)",
"Bash(echo \"exit:$?\")", "Bash(echo \"exit:$?\")",
"Bash(python -c \"from config import get_settings; s = get_settings\\(\\); print\\('SA JSON set:', bool\\(s.firebase_service_account_json\\)\\)\")" "Bash(python -c \"from config import get_settings; s = get_settings\\(\\); print\\('SA JSON set:', bool\\(s.firebase_service_account_json\\)\\)\")",
"Bash(python3 -c ' *)"
] ]
} }
} }

47
backend/auth.py Normal file
View File

@@ -0,0 +1,47 @@
"""Firebase token verification and ownership helpers."""
import logging
from fastapi import HTTPException, Header
from firebase_admin import auth as firebase_auth
import firebase_admin
from bson import ObjectId
from bson.errors import InvalidId
log = logging.getLogger(__name__)
async def get_current_user(authorization: str = Header(..., alias="Authorization")) -> dict:
"""FastAPI dependency: verifies Firebase ID token and returns decoded payload."""
if not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Invalid authorization header")
token = authorization[len("Bearer "):]
if not firebase_admin._apps:
raise HTTPException(status_code=503, detail="Authentication service unavailable")
try:
return firebase_auth.verify_id_token(token)
except firebase_auth.ExpiredIdTokenError:
raise HTTPException(status_code=401, detail="Token expired")
except Exception:
raise HTTPException(status_code=401, detail="Invalid token")
def verify_user_access(user_id: str, db, token: dict) -> dict:
"""
Fetch user by ObjectId and confirm the token owner matches.
Returns the user document. Raises 400/404/403 on failure.
"""
try:
user_oid = ObjectId(user_id)
except InvalidId:
raise HTTPException(status_code=400, detail="Invalid user ID format")
user = db.users.find_one({"_id": user_oid})
if not user:
raise HTTPException(status_code=404, detail="User not found")
if user.get("email") != token.get("email"):
raise HTTPException(status_code=403, detail="Access denied")
return user

View File

@@ -52,7 +52,7 @@ app.add_middleware(
allow_origins=cors_origins, allow_origins=cors_origins,
allow_credentials=True, allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
allow_headers=["*"], allow_headers=["Authorization", "Content-Type"],
) )
# Include routers # Include routers

View File

@@ -1,13 +1,13 @@
"""Journal entry routes""" """Journal entry routes"""
from fastapi import APIRouter, HTTPException, Query import logging
from fastapi import APIRouter, HTTPException, Query, Depends
from db import get_database from db import get_database
from models import JournalEntryCreate, JournalEntryUpdate, JournalEntry, EntriesListResponse, PaginationMeta from models import JournalEntryCreate, JournalEntryUpdate
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Optional from auth import get_current_user, verify_user_access
from bson import ObjectId
from bson.errors import InvalidId
from utils import format_ist_timestamp from utils import format_ist_timestamp
log = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
@@ -16,21 +16,20 @@ def _format_entry(entry: dict) -> dict:
return { return {
"id": str(entry["_id"]), "id": str(entry["_id"]),
"userId": str(entry["userId"]), "userId": str(entry["userId"]),
"title": entry.get("title"), # None if encrypted "title": entry.get("title"),
"content": entry.get("content"), # None if encrypted "content": entry.get("content"),
"mood": entry.get("mood"), "mood": entry.get("mood"),
"tags": entry.get("tags", []), "tags": entry.get("tags", []),
"isPublic": entry.get("isPublic", False), "isPublic": entry.get("isPublic", False),
"entryDate": entry.get("entryDate", entry.get("createdAt")).isoformat() if entry.get("entryDate") or entry.get("createdAt") else None, "entryDate": entry.get("entryDate", entry.get("createdAt")).isoformat() if entry.get("entryDate") or entry.get("createdAt") else None,
"createdAt": entry["createdAt"].isoformat(), "createdAt": entry["createdAt"].isoformat(),
"updatedAt": entry["updatedAt"].isoformat(), "updatedAt": entry["updatedAt"].isoformat(),
# Full encryption metadata including ciphertext and nonce
"encryption": entry.get("encryption") "encryption": entry.get("encryption")
} }
@router.post("/{user_id}", response_model=dict) @router.post("/{user_id}", response_model=dict)
async def create_entry(user_id: str, entry_data: JournalEntryCreate): async def create_entry(user_id: str, entry_data: JournalEntryCreate, token: dict = Depends(get_current_user)):
""" """
Create a new journal entry. Create a new journal entry.
@@ -38,33 +37,18 @@ async def create_entry(user_id: str, entry_data: JournalEntryCreate):
- Send encryption metadata with ciphertext and nonce - Send encryption metadata with ciphertext and nonce
- Omit title and content (they're encrypted in ciphertext) - Omit title and content (they're encrypted in ciphertext)
For unencrypted entries (deprecated):
- Send title and content directly
entryDate: The logical journal date for this entry (defaults to today UTC). entryDate: The logical journal date for this entry (defaults to today UTC).
createdAt: Database write timestamp.
Server stores only: encrypted ciphertext, nonce, and metadata. Server stores only: encrypted ciphertext, nonce, and metadata.
Server never sees plaintext. Server never sees plaintext.
""" """
db = get_database() db = get_database()
try: try:
user_oid = ObjectId(user_id) user = verify_user_access(user_id, db, token)
except InvalidId: user_oid = user["_id"]
raise HTTPException(status_code=400, detail="Invalid user ID format")
try:
# Verify user exists
user = db.users.find_one({"_id": user_oid})
if not user:
raise HTTPException(status_code=404, detail="User not found")
now = datetime.utcnow() now = datetime.utcnow()
entry_date = entry_data.entryDate or now.replace( entry_date = entry_data.entryDate or now.replace(hour=0, minute=0, second=0, microsecond=0)
hour=0, minute=0, second=0, microsecond=0)
# Validate encryption metadata if present
if entry_data.encryption: if entry_data.encryption:
if not entry_data.encryption.ciphertext or not entry_data.encryption.nonce: if not entry_data.encryption.ciphertext or not entry_data.encryption.nonce:
raise HTTPException( raise HTTPException(
@@ -74,12 +58,12 @@ async def create_entry(user_id: str, entry_data: JournalEntryCreate):
entry_doc = { entry_doc = {
"userId": user_oid, "userId": user_oid,
"title": entry_data.title, # None if encrypted "title": entry_data.title,
"content": entry_data.content, # None if encrypted "content": entry_data.content,
"mood": entry_data.mood, "mood": entry_data.mood,
"tags": entry_data.tags or [], "tags": entry_data.tags or [],
"isPublic": entry_data.isPublic or False, "isPublic": entry_data.isPublic or False,
"entryDate": entry_date, # Logical journal date "entryDate": entry_date,
"createdAt": now, "createdAt": now,
"updatedAt": now, "updatedAt": now,
"encryption": entry_data.encryption.model_dump() if entry_data.encryption else None "encryption": entry_data.encryption.model_dump() if entry_data.encryption else None
@@ -94,48 +78,29 @@ async def create_entry(user_id: str, entry_data: JournalEntryCreate):
} }
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception:
raise HTTPException( log.exception("Failed to create entry")
status_code=500, detail=f"Failed to create entry: {str(e)}") raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/{user_id}") @router.get("/{user_id}")
async def get_user_entries( async def get_user_entries(
user_id: str, user_id: str,
limit: int = Query(50, ge=1, le=100), limit: int = Query(50, ge=1, le=100),
skip: int = Query(0, ge=0) skip: int = Query(0, ge=0),
token: dict = Depends(get_current_user)
): ):
""" """Get paginated entries for a user (most recent first)."""
Get paginated entries for a user (most recent first).
Supports pagination via skip and limit.
"""
db = get_database() db = get_database()
try: try:
user_oid = ObjectId(user_id) user = verify_user_access(user_id, db, token)
except InvalidId: user_oid = user["_id"]
raise HTTPException(status_code=400, detail="Invalid user ID format")
try:
# Verify user exists
user = db.users.find_one({"_id": user_oid})
if not user:
raise HTTPException(status_code=404, detail="User not found")
# Get entries
entries = list( entries = list(
db.entries.find( db.entries.find({"userId": user_oid}).sort("createdAt", -1).skip(skip).limit(limit)
{"userId": user_oid}
).sort("createdAt", -1).skip(skip).limit(limit)
) )
# Format entries
formatted_entries = [_format_entry(entry) for entry in entries] formatted_entries = [_format_entry(entry) for entry in entries]
# Get total count
total = db.entries.count_documents({"userId": user_oid}) total = db.entries.count_documents({"userId": user_oid})
has_more = (skip + limit) < total
return { return {
"entries": formatted_entries, "entries": formatted_entries,
@@ -143,101 +108,95 @@ async def get_user_entries(
"total": total, "total": total,
"limit": limit, "limit": limit,
"skip": skip, "skip": skip,
"hasMore": has_more "hasMore": (skip + limit) < total
} }
} }
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception:
raise HTTPException( log.exception("Failed to fetch entries")
status_code=500, detail=f"Failed to fetch entries: {str(e)}") raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/{user_id}/{entry_id}") @router.get("/{user_id}/{entry_id}")
async def get_entry(user_id: str, entry_id: str): async def get_entry(user_id: str, entry_id: str, token: dict = Depends(get_current_user)):
"""Get a specific entry by ID.""" """Get a specific entry by ID."""
from bson import ObjectId
from bson.errors import InvalidId
db = get_database() db = get_database()
try: try:
user_oid = ObjectId(user_id) user = verify_user_access(user_id, db, token)
user_oid = user["_id"]
try:
entry_oid = ObjectId(entry_id) entry_oid = ObjectId(entry_id)
except InvalidId: except InvalidId:
raise HTTPException(status_code=400, detail="Invalid ID format") raise HTTPException(status_code=400, detail="Invalid entry ID format")
try:
entry = db.entries.find_one({
"_id": entry_oid,
"userId": user_oid
})
entry = db.entries.find_one({"_id": entry_oid, "userId": user_oid})
if not entry: if not entry:
raise HTTPException(status_code=404, detail="Entry not found") raise HTTPException(status_code=404, detail="Entry not found")
return _format_entry(entry) return _format_entry(entry)
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception:
raise HTTPException( log.exception("Failed to fetch entry")
status_code=500, detail=f"Failed to fetch entry: {str(e)}") raise HTTPException(status_code=500, detail="Internal server error")
@router.put("/{user_id}/{entry_id}") @router.put("/{user_id}/{entry_id}")
async def update_entry(user_id: str, entry_id: str, entry_data: JournalEntryUpdate): async def update_entry(user_id: str, entry_id: str, entry_data: JournalEntryUpdate, token: dict = Depends(get_current_user)):
"""Update a journal entry.""" """Update a journal entry."""
from bson import ObjectId
from bson.errors import InvalidId
db = get_database() db = get_database()
try: try:
user_oid = ObjectId(user_id) user = verify_user_access(user_id, db, token)
user_oid = user["_id"]
try:
entry_oid = ObjectId(entry_id) entry_oid = ObjectId(entry_id)
except InvalidId: except InvalidId:
raise HTTPException(status_code=400, detail="Invalid ID format") raise HTTPException(status_code=400, detail="Invalid entry ID format")
try:
update_data = entry_data.model_dump(exclude_unset=True) update_data = entry_data.model_dump(exclude_unset=True)
update_data["updatedAt"] = datetime.utcnow() update_data["updatedAt"] = datetime.utcnow()
# If entryDate provided in update data, ensure it's a datetime
if "entryDate" in update_data and isinstance(update_data["entryDate"], str): if "entryDate" in update_data and isinstance(update_data["entryDate"], str):
update_data["entryDate"] = datetime.fromisoformat( update_data["entryDate"] = datetime.fromisoformat(
update_data["entryDate"].replace("Z", "+00:00")) update_data["entryDate"].replace("Z", "+00:00"))
result = db.entries.update_one( result = db.entries.update_one(
{ {"_id": entry_oid, "userId": user_oid},
"_id": entry_oid,
"userId": user_oid
},
{"$set": update_data} {"$set": update_data}
) )
if result.matched_count == 0: if result.matched_count == 0:
raise HTTPException(status_code=404, detail="Entry not found") raise HTTPException(status_code=404, detail="Entry not found")
# Fetch and return updated entry
entry = db.entries.find_one({"_id": entry_oid}) entry = db.entries.find_one({"_id": entry_oid})
return _format_entry(entry) return _format_entry(entry)
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception:
raise HTTPException( log.exception("Failed to update entry")
status_code=500, detail=f"Failed to update entry: {str(e)}") raise HTTPException(status_code=500, detail="Internal server error")
@router.delete("/{user_id}/{entry_id}") @router.delete("/{user_id}/{entry_id}")
async def delete_entry(user_id: str, entry_id: str): async def delete_entry(user_id: str, entry_id: str, token: dict = Depends(get_current_user)):
"""Delete a journal entry.""" """Delete a journal entry."""
from bson import ObjectId
from bson.errors import InvalidId
db = get_database() db = get_database()
try: try:
user_oid = ObjectId(user_id) user = verify_user_access(user_id, db, token)
user_oid = user["_id"]
try:
entry_oid = ObjectId(entry_id) entry_oid = ObjectId(entry_id)
except InvalidId: except InvalidId:
raise HTTPException(status_code=400, detail="Invalid ID format") raise HTTPException(status_code=400, detail="Invalid entry ID format")
try: result = db.entries.delete_one({"_id": entry_oid, "userId": user_oid})
result = db.entries.delete_one({
"_id": entry_oid,
"userId": user_oid
})
if result.deleted_count == 0: if result.deleted_count == 0:
raise HTTPException(status_code=404, detail="Entry not found") raise HTTPException(status_code=404, detail="Entry not found")
@@ -245,108 +204,83 @@ async def delete_entry(user_id: str, entry_id: str):
return {"message": "Entry deleted successfully"} return {"message": "Entry deleted successfully"}
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception:
raise HTTPException( log.exception("Failed to delete entry")
status_code=500, detail=f"Failed to delete entry: {str(e)}") raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/{user_id}/by-date/{date_str}") @router.get("/{user_id}/by-date/{date_str}")
async def get_entries_by_date(user_id: str, date_str: str): async def get_entries_by_date(user_id: str, date_str: str, token: dict = Depends(get_current_user)):
""" """Get entries for a specific date (format: YYYY-MM-DD)."""
Get entries for a specific date (format: YYYY-MM-DD).
Matches entries by entryDate field.
"""
db = get_database() db = get_database()
try:
user = verify_user_access(user_id, db, token)
user_oid = user["_id"]
try: try:
user_oid = ObjectId(user_id)
except InvalidId:
raise HTTPException(status_code=400, detail="Invalid user ID format")
try:
# Parse date
target_date = datetime.strptime(date_str, "%Y-%m-%d") target_date = datetime.strptime(date_str, "%Y-%m-%d")
except ValueError:
raise HTTPException(status_code=400, detail="Invalid date format. Use YYYY-MM-DD")
next_date = target_date + timedelta(days=1) next_date = target_date + timedelta(days=1)
entries = list( entries = list(
db.entries.find({ db.entries.find({
"userId": user_oid, "userId": user_oid,
"entryDate": { "entryDate": {"$gte": target_date, "$lt": next_date}
"$gte": target_date,
"$lt": next_date
}
}).sort("createdAt", -1) }).sort("createdAt", -1)
) )
formatted_entries = [_format_entry(entry) for entry in entries]
return { return {
"entries": formatted_entries, "entries": [_format_entry(e) for e in entries],
"date": date_str, "date": date_str,
"count": len(formatted_entries) "count": len(entries)
} }
except ValueError:
raise HTTPException(
status_code=400, detail="Invalid date format. Use YYYY-MM-DD")
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception:
raise HTTPException( log.exception("Failed to fetch entries by date")
status_code=500, detail=f"Failed to fetch entries: {str(e)}") raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/{user_id}/by-month/{year}/{month}") @router.get("/{user_id}/by-month/{year}/{month}")
async def get_entries_by_month(user_id: str, year: int, month: int, limit: int = Query(100, ge=1, le=500)): async def get_entries_by_month(
""" user_id: str,
Get entries for a specific month (for calendar view). year: int,
month: int,
Query format: GET /api/entries/{user_id}/by-month/{year}/{month}?limit=100 limit: int = Query(100, ge=1, le=500),
""" token: dict = Depends(get_current_user)
):
"""Get entries for a specific month (for calendar view)."""
db = get_database() db = get_database()
try: try:
user_oid = ObjectId(user_id) user = verify_user_access(user_id, db, token)
except InvalidId: user_oid = user["_id"]
raise HTTPException(status_code=400, detail="Invalid user ID format")
try:
if not (1 <= month <= 12): if not (1 <= month <= 12):
raise HTTPException( raise HTTPException(status_code=400, detail="Month must be between 1 and 12")
status_code=400, detail="Month must be between 1 and 12")
# Calculate date range
start_date = datetime(year, month, 1) start_date = datetime(year, month, 1)
if month == 12: end_date = datetime(year + 1, 1, 1) if month == 12 else datetime(year, month + 1, 1)
end_date = datetime(year + 1, 1, 1)
else:
end_date = datetime(year, month + 1, 1)
entries = list( entries = list(
db.entries.find({ db.entries.find({
"userId": user_oid, "userId": user_oid,
"entryDate": { "entryDate": {"$gte": start_date, "$lt": end_date}
"$gte": start_date,
"$lt": end_date
}
}).sort("entryDate", -1).limit(limit) }).sort("entryDate", -1).limit(limit)
) )
formatted_entries = [_format_entry(entry) for entry in entries]
return { return {
"entries": formatted_entries, "entries": [_format_entry(e) for e in entries],
"year": year, "year": year,
"month": month, "month": month,
"count": len(formatted_entries) "count": len(entries)
} }
except ValueError:
raise HTTPException(status_code=400, detail="Invalid year or month")
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception:
raise HTTPException( log.exception("Failed to fetch entries by month")
status_code=500, detail=f"Failed to fetch entries: {str(e)}") raise HTTPException(status_code=500, detail="Internal server error")
@router.post("/convert-timestamp/utc-to-ist") @router.post("/convert-timestamp/utc-to-ist")
@@ -355,18 +289,14 @@ async def convert_utc_to_ist(data: dict):
try: try:
utc_timestamp = data.get("timestamp") utc_timestamp = data.get("timestamp")
if not utc_timestamp: if not utc_timestamp:
raise HTTPException( raise HTTPException(status_code=400, detail="Missing 'timestamp' field")
status_code=400, detail="Missing 'timestamp' field")
ist_timestamp = format_ist_timestamp(utc_timestamp) ist_timestamp = format_ist_timestamp(utc_timestamp)
return { return {"utc": utc_timestamp, "ist": ist_timestamp}
"utc": utc_timestamp,
"ist": ist_timestamp
}
except HTTPException: except HTTPException:
raise raise
except ValueError as e: except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
except Exception as e: except Exception:
raise HTTPException( log.exception("Timestamp conversion failed")
status_code=500, detail=f"Conversion failed: {str(e)}") raise HTTPException(status_code=500, detail="Internal server error")

View File

@@ -1,12 +1,13 @@
"""Notification routes — FCM token registration and reminder settings.""" """Notification routes — FCM token registration and reminder settings."""
from fastapi import APIRouter, HTTPException import logging
from fastapi import APIRouter, HTTPException, Depends
from db import get_database from db import get_database
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional from typing import Optional
from bson import ObjectId
from bson.errors import InvalidId
from datetime import datetime from datetime import datetime
from auth import get_current_user, verify_user_access
log = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
@@ -22,23 +23,16 @@ class ReminderSettingsRequest(BaseModel):
@router.post("/fcm-token", response_model=dict) @router.post("/fcm-token", response_model=dict)
async def register_fcm_token(body: FcmTokenRequest): async def register_fcm_token(body: FcmTokenRequest, token: dict = Depends(get_current_user)):
""" """
Register (or refresh) an FCM device token for a user. Register (or refresh) an FCM device token for a user.
Stores unique tokens per user — duplicate tokens are ignored. Stores unique tokens per user — duplicate tokens are ignored.
""" """
db = get_database() db = get_database()
try: try:
user_oid = ObjectId(body.userId) user = verify_user_access(body.userId, db, token)
except InvalidId: user_oid = user["_id"]
raise HTTPException(status_code=400, detail="Invalid user ID")
user = db.users.find_one({"_id": user_oid})
if not user:
raise HTTPException(status_code=404, detail="User not found")
# Add token to set (avoid duplicates)
db.users.update_one( db.users.update_one(
{"_id": user_oid}, {"_id": user_oid},
{ {
@@ -47,23 +41,20 @@ async def register_fcm_token(body: FcmTokenRequest):
} }
) )
return {"message": "FCM token registered"} return {"message": "FCM token registered"}
except HTTPException:
raise
except Exception:
log.exception("Failed to register FCM token")
raise HTTPException(status_code=500, detail="Internal server error")
@router.put("/reminder/{user_id}", response_model=dict) @router.put("/reminder/{user_id}", response_model=dict)
async def update_reminder(user_id: str, settings: ReminderSettingsRequest): async def update_reminder(user_id: str, settings: ReminderSettingsRequest, token: dict = Depends(get_current_user)):
""" """Save or update daily reminder settings for a user."""
Save or update daily reminder settings for a user.
"""
db = get_database() db = get_database()
try: try:
user_oid = ObjectId(user_id) user = verify_user_access(user_id, db, token)
except InvalidId: user_oid = user["_id"]
raise HTTPException(status_code=400, detail="Invalid user ID")
user = db.users.find_one({"_id": user_oid})
if not user:
raise HTTPException(status_code=404, detail="User not found")
reminder_update: dict = {"reminder.enabled": settings.enabled} reminder_update: dict = {"reminder.enabled": settings.enabled}
if settings.time is not None: if settings.time is not None:
@@ -76,3 +67,8 @@ async def update_reminder(user_id: str, settings: ReminderSettingsRequest):
{"$set": {**reminder_update, "updatedAt": datetime.utcnow()}} {"$set": {**reminder_update, "updatedAt": datetime.utcnow()}}
) )
return {"message": "Reminder settings updated"} return {"message": "Reminder settings updated"}
except HTTPException:
raise
except Exception:
log.exception("Failed to update reminder settings")
raise HTTPException(status_code=500, detail="Internal server error")

View File

@@ -1,17 +1,17 @@
"""User management routes""" """User management routes"""
from fastapi import APIRouter, HTTPException import logging
from fastapi import APIRouter, HTTPException, Depends
from db import get_database from db import get_database
from models import UserCreate, UserUpdate, User from models import UserCreate, UserUpdate
from datetime import datetime from datetime import datetime
from typing import Optional from auth import get_current_user, verify_user_access
from bson import ObjectId
from bson.errors import InvalidId
log = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
@router.post("/register", response_model=dict) @router.post("/register", response_model=dict)
async def register_user(user_data: UserCreate): async def register_user(user_data: UserCreate, token: dict = Depends(get_current_user)):
""" """
Register or get user (idempotent). Register or get user (idempotent).
@@ -19,10 +19,11 @@ async def register_user(user_data: UserCreate):
If user already exists, returns existing user. If user already exists, returns existing user.
Called after Firebase Google Auth on frontend. Called after Firebase Google Auth on frontend.
""" """
db = get_database() if user_data.email != token.get("email"):
raise HTTPException(status_code=403, detail="Access denied")
db = get_database()
try: try:
# Upsert: Update if exists, insert if not
result = db.users.update_one( result = db.users.update_one(
{"email": user_data.email}, {"email": user_data.email},
{ {
@@ -40,11 +41,9 @@ async def register_user(user_data: UserCreate):
upsert=True upsert=True
) )
# Fetch the user (either newly created or existing)
user = db.users.find_one({"email": user_data.email}) user = db.users.find_one({"email": user_data.email})
if not user: if not user:
raise HTTPException( raise HTTPException(status_code=500, detail="Failed to retrieve user after upsert")
status_code=500, detail="Failed to retrieve user after upsert")
return { return {
"id": str(user["_id"]), "id": str(user["_id"]),
@@ -62,15 +61,17 @@ async def register_user(user_data: UserCreate):
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
raise HTTPException( log.exception("Registration failed")
status_code=500, detail=f"Registration failed: {str(e)}") raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/by-email/{email}", response_model=dict) @router.get("/by-email/{email}", response_model=dict)
async def get_user_by_email(email: str): async def get_user_by_email(email: str, token: dict = Depends(get_current_user)):
"""Get user profile by email (called after Firebase Auth).""" """Get user profile by email (called after Firebase Auth)."""
db = get_database() if email != token.get("email"):
raise HTTPException(status_code=403, detail="Access denied")
db = get_database()
try: try:
user = db.users.find_one({"email": email}) user = db.users.find_one({"email": email})
if not user: if not user:
@@ -91,26 +92,17 @@ async def get_user_by_email(email: str):
} }
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception:
raise HTTPException( log.exception("Failed to fetch user by email")
status_code=500, detail=f"Failed to fetch user: {str(e)}") raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/{user_id}", response_model=dict) @router.get("/{user_id}", response_model=dict)
async def get_user_by_id(user_id: str): async def get_user_by_id(user_id: str, token: dict = Depends(get_current_user)):
"""Get user profile by ID.""" """Get user profile by ID."""
db = get_database() db = get_database()
try: try:
user_oid = ObjectId(user_id) user = verify_user_access(user_id, db, token)
except InvalidId:
raise HTTPException(status_code=400, detail="Invalid user ID format")
try:
user = db.users.find_one({"_id": user_oid})
if not user:
raise HTTPException(status_code=404, detail="User not found")
return { return {
"id": str(user["_id"]), "id": str(user["_id"]),
"email": user["email"], "email": user["email"],
@@ -124,72 +116,54 @@ async def get_user_by_id(user_id: str):
} }
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception:
raise HTTPException( log.exception("Failed to fetch user by ID")
status_code=500, detail=f"Failed to fetch user: {str(e)}") raise HTTPException(status_code=500, detail="Internal server error")
@router.put("/{user_id}", response_model=dict) @router.put("/{user_id}", response_model=dict)
async def update_user(user_id: str, user_data: UserUpdate): async def update_user(user_id: str, user_data: UserUpdate, token: dict = Depends(get_current_user)):
"""Update user profile.""" """Update user profile."""
db = get_database() db = get_database()
try: try:
user_oid = ObjectId(user_id) user = verify_user_access(user_id, db, token)
except InvalidId: user_oid = user["_id"]
raise HTTPException(status_code=400, detail="Invalid user ID format")
try:
# Prepare update data (exclude None values)
update_data = user_data.model_dump(exclude_unset=True) update_data = user_data.model_dump(exclude_unset=True)
update_data["updatedAt"] = datetime.utcnow() update_data["updatedAt"] = datetime.utcnow()
result = db.users.update_one( db.users.update_one({"_id": user_oid}, {"$set": update_data})
{"_id": user_oid},
{"$set": update_data}
)
if result.matched_count == 0: updated = db.users.find_one({"_id": user_oid})
raise HTTPException(status_code=404, detail="User not found")
# Fetch and return updated user
user = db.users.find_one({"_id": user_oid})
return { return {
"id": str(user["_id"]), "id": str(updated["_id"]),
"email": user["email"], "email": updated["email"],
"displayName": user.get("displayName"), "displayName": updated.get("displayName"),
"photoURL": user.get("photoURL"), "photoURL": updated.get("photoURL"),
"theme": user.get("theme", "light"), "theme": updated.get("theme", "light"),
"backgroundImage": user.get("backgroundImage"), "backgroundImage": updated.get("backgroundImage"),
"backgroundImages": user.get("backgroundImages", []), "backgroundImages": updated.get("backgroundImages", []),
"tutorial": user.get("tutorial"), "tutorial": updated.get("tutorial"),
"createdAt": user["createdAt"].isoformat(), "createdAt": updated["createdAt"].isoformat(),
"updatedAt": user["updatedAt"].isoformat(), "updatedAt": updated["updatedAt"].isoformat(),
"message": "User updated successfully" "message": "User updated successfully"
} }
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception:
raise HTTPException(status_code=500, detail=f"Update failed: {str(e)}") log.exception("User update failed")
raise HTTPException(status_code=500, detail="Internal server error")
@router.delete("/{user_id}") @router.delete("/{user_id}")
async def delete_user(user_id: str): async def delete_user(user_id: str, token: dict = Depends(get_current_user)):
"""Delete user account and all associated data.""" """Delete user account and all associated data."""
db = get_database() db = get_database()
try: try:
user_oid = ObjectId(user_id) user = verify_user_access(user_id, db, token)
except InvalidId: user_oid = user["_id"]
raise HTTPException(status_code=400, detail="Invalid user ID format")
try:
# Delete user
user_result = db.users.delete_one({"_id": user_oid}) user_result = db.users.delete_one({"_id": user_oid})
if user_result.deleted_count == 0:
raise HTTPException(status_code=404, detail="User not found")
# Delete all user's entries
entry_result = db.entries.delete_many({"userId": user_oid}) entry_result = db.entries.delete_many({"userId": user_oid})
return { return {
@@ -199,6 +173,6 @@ async def delete_user(user_id: str):
} }
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception:
raise HTTPException( log.exception("User deletion failed")
status_code=500, detail=f"Deletion failed: {str(e)}") raise HTTPException(status_code=500, detail="Internal server error")