Разделил модели базы данных и модели pydantic.

This commit is contained in:
DmitryGantimurov 2023-07-17 23:34:04 +03:00
parent 808edad6b4
commit 09ba6a3478
5 changed files with 44 additions and 114 deletions

View File

@ -18,61 +18,4 @@ engine = create_engine(
SessionLocal = sessionmaker(bind=engine, expire_on_commit=False) SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)
database = SessionLocal() database = SessionLocal()
Base = declarative_base() Base = declarative_base()
class UserDatabase(Base):#класс пользователя
__tablename__ = "users"
id = Column(Integer, primary_key=True, index=True)#айди пользователя
phone = Column(Integer, nullable=True)#номер телефона пользователя
email = Column(String)#электронная почта пользователя
password = Column(String) # пароль
hashed_password = Column(String)
name = Column(String, nullable=True)#имя пользователя
surname = Column(String)#фамилия пользователя
class Announcement(Base): #класс объявления
__tablename__ = "announcements"
id = Column(Integer, primary_key=True, index=True)#айди объявления
user_id = Column(Integer)#айди создателя объявления
name = Column(String) # название объявления
category = Column(String)#категория продукта из объявления
best_by = Column(Integer)#срок годности продукта из объявления
address = Column(String)
longtitude = Column(Integer)
latitude = Column(Integer)
description = Column(String)#описание продукта в объявлении
src = Column(String, nullable=True) #изображение продукта в объявлении
metro = Column(String)#ближайщее метро от адреса нахождения продукта
trashId = Column(Integer, nullable=True)
booked_by = Column(Integer)#статус бронирования (либо -1, либо айди бронирующего)
class Trashbox(Base):#класс мусорных баков
__tablename__ = "trashboxes"
id = Column(Integer, primary_key=True, index=True)#айди
name = Column(String, nullable=True)#имя пользователя
address = Column(String)
latitude = Column(Integer)
longtitude = Column(Integer)
category = Column(String)#категория продукта из объявления
# ### Функции понадобятся, когда приложение приложение будет более развитым
# # This function can be called during the initialization of the FastAPI app.
# async def create_db_and_tables():
# async with engine.begin() as conn:
# await conn.run_sync(Base.metadata.create_all)
# async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
# async with async_session_maker() as session:
# yield session
# async def get_user_db(session: AsyncSession = Depends(get_async_session)):
# yield SQLAlchemyUserDatabase(session, User)

View File

@ -20,7 +20,8 @@ import shutil
import os import os
from .utils import * from .utils import *
from .db import Announcement, Trashbox, UserDatabase, Base, engine, SessionLocal, database from .db import Base, engine, SessionLocal, database
from .models import Announcement, Trashbox, UserDatabase
from . import schema from . import schema

View File

@ -1,23 +1,7 @@
from typing import AsyncGenerator
from sqlalchemy import Column, Integer, String from sqlalchemy import Column, Integer, String
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase
from fastapi import Depends from fastapi import Depends
from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase from .db import Base
SQLALCHEMY_DATABASE_URL = "sqlite:///./sql_app.db"
engine = create_async_engine(
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
SessionLocal = async_sessionmaker(bind=engine, expire_on_commit=False)
Base = declarative_base()
class UserDatabase(Base):#класс пользователя class UserDatabase(Base):#класс пользователя
__tablename__ = "users" __tablename__ = "users"
@ -60,16 +44,19 @@ class Trashbox(Base):#класс мусорных баков
category = Column(String)#категория продукта из объявления category = Column(String)#категория продукта из объявления
# This function can be called during the initialization of the FastAPI app. # from typing import AsyncGenerator
async def create_db_and_tables(): # from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
async with engine.begin() as conn: # from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase
await conn.run_sync(Base.metadata.create_all) # # This function can be called during the initialization of the FastAPI app.
# async def create_db_and_tables():
# async with engine.begin() as conn:
# await conn.run_sync(Base.metadata.create_all)
async def get_async_session() -> AsyncGenerator[AsyncSession, None]: # async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
async with async_session_maker() as session: # async with async_session_maker() as session:
yield session # yield session
async def get_user_db(session: AsyncSession = Depends(get_async_session)): # async def get_user_db(session: AsyncSession = Depends(get_async_session)):
yield SQLAlchemyUserDatabase(session, User) # yield SQLAlchemyUserDatabase(session, User)

View File

@ -1,5 +1,24 @@
from pydantic import BaseModel from pydantic import BaseModel
from typing import Annotated, Union
class Book(BaseModel): class Book(BaseModel):
id: int id: int
class Token(BaseModel):
access_token: str
token_type: str
class TokenData(BaseModel):
email: Union[str, None] = None
class User(BaseModel):
email: Union[str, None] = None
full_name: Union[str, None] = None
disabled: Union[bool, None] = None
class UserInDB(User):
hashed_password: str

View File

@ -5,15 +5,15 @@ from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt from jose import JWTError, jwt
from passlib.context import CryptContext from passlib.context import CryptContext
from pydantic import BaseModel
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy import select from sqlalchemy import select
from .db import UserDatabase, SessionLocal, database from .db import SessionLocal, database
from .models import UserDatabase
from .schema import Token, TokenData, UserInDB, User
# to get a string like this run:
# openssl rand -hex 32
SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7" SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256" ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30 ACCESS_TOKEN_EXPIRE_MINUTES = 30
@ -29,26 +29,6 @@ ACCESS_TOKEN_EXPIRE_MINUTES = 30
# } # }
# } # }
class Token(BaseModel):
access_token: str
token_type: str
class TokenData(BaseModel):
email: Union[str, None] = None
class User(BaseModel):
email: Union[str, None] = None
full_name: Union[str, None] = None
disabled: Union[bool, None] = None
class UserInDB(User):
hashed_password: str
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
@ -63,9 +43,9 @@ def get_password_hash(password):
# проблема здесь # проблема здесь
def get_user(db: SessionLocal, email: str): def get_user(db: SessionLocal, email: str):
user_with_required_email = db.query(UserDatabase).filter(UserDatabase.email == email).first() user_with_required_email = db.query(UserDatabase).filter(UserDatabase.email == email).one()
if user_with_required_email: if user_with_required_email:
return UserInDB(user_with_required_email) ## выдает ошибку о том, что 2 аргумента, хотя я передаю 1 return user_with_required_email
return None return None
@ -89,7 +69,7 @@ def create_access_token(data: dict, expires_delta: Union[timedelta, None] = None
return encoded_jwt return encoded_jwt
async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]): async def get_current_user(db: SessionLocal, token: Annotated[str, Depends(oauth2_scheme)]):
credentials_exception = HTTPException( credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials", detail="Could not validate credentials",
@ -103,7 +83,7 @@ async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
token_data = TokenData(email=email) token_data = TokenData(email=email)
except JWTError: except JWTError:
raise credentials_exception raise credentials_exception
user = get_user(fake_users_db, email=token_data.email) user = get_user(db, email=token_data.email)
if user is None: if user is None:
raise credentials_exception raise credentials_exception
return user return user