From 2028dd94193458ac935bf2ea443141bd2511f296 Mon Sep 17 00:00:00 2001 From: dm1sh Date: Tue, 1 Aug 2023 12:15:23 +0300 Subject: [PATCH] Fixed get_current_user db argument setting --- back/main.py | 2 +- back/utils.py | 18 ++++++++++-------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/back/main.py b/back/main.py index c9f994d..772fc43 100644 --- a/back/main.py +++ b/back/main.py @@ -172,7 +172,7 @@ async def login_for_access_token( # async def read_users_me(current_user: Annotated[schemas.User, Depends(get_current_active_user)]): # return current_user #schemas.User(id=current_user.id, email=current_user.email, name=current_user.name, surname=current_user.surname, disabled=current_user.disabled, items=current_user.items) @app.get("/api/users/me/", response_model=schemas.User) # -async def read_users_me(current_user: Annotated[schemas.User, Depends(get_current_active_user)]) -> Any: +async def read_users_me(current_user: Annotated[schemas.User, Depends(get_current_active_user)]): return current_user #schemas.User(id=current_user.id, email=current_user.email, name=current_user.name, surname=current_user.surname, disabled=current_user.disabled, items=current_user.items) diff --git a/back/utils.py b/back/utils.py index 4b70789..64d378e 100644 --- a/back/utils.py +++ b/back/utils.py @@ -12,6 +12,7 @@ from sqlalchemy import select # from .db import Session, database from . import models, schemas +from .db import SessionLocal @@ -22,6 +23,14 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") +def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.close() + + def verify_password(plain_password, hashed_password): return pwd_context.verify(plain_password, hashed_password) @@ -57,7 +66,7 @@ def create_access_token(data: dict, expires_delta: Union[timedelta, None] = None return encoded_jwt -async def get_current_user(db: Session, token: Annotated[str, Depends(oauth2_scheme)]): +async def get_current_user(db: Annotated[Session, Depends(get_db)], token: Annotated[str, Depends(oauth2_scheme)]): credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", @@ -88,10 +97,3 @@ async def get_current_active_user( # def get_db(request: Request): # return request.state.db - -# def get_db(): -# db = SessionLocal() -# try: -# yield db -# finally: -# db.close() \ No newline at end of file