base.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
  2. from fastapi.encoders import jsonable_encoder
  3. from pydantic import BaseModel
  4. from sqlalchemy.orm import Session
  5. from app.db.base_class import Base
  6. ModelType = TypeVar("ModelType", bound=Base)
  7. CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
  8. UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
  9. class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
  10. def __init__(self, model: Type[ModelType]):
  11. """
  12. CRUD object with default methods to Create, Read, Update, Delete (CRUD).
  13. **Parameters**
  14. * `model`: A SQLAlchemy model class
  15. * `schema`: A Pydantic model (schema) class
  16. """
  17. self.model = model
  18. def get(self, db: Session, id: Any) -> Optional[ModelType]:
  19. return db.query(self.model).filter(self.model.id == id).first()
  20. def get_multi(
  21. self, db: Session, *, skip: int = 0, limit: int = 100
  22. ) -> List[ModelType]:
  23. return db.query(self.model).offset(skip).limit(limit).all()
  24. def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
  25. obj_in_data = jsonable_encoder(obj_in)
  26. db_obj = self.model(**obj_in_data) # type: ignore
  27. db.add(db_obj)
  28. db.commit()
  29. db.refresh(db_obj)
  30. return db_obj
  31. def update(
  32. self,
  33. db: Session,
  34. *,
  35. db_obj: ModelType,
  36. obj_in: Union[UpdateSchemaType, Dict[str, Any]]
  37. ) -> ModelType:
  38. obj_data = jsonable_encoder(db_obj)
  39. if isinstance(obj_in, dict):
  40. update_data = obj_in
  41. else:
  42. update_data = obj_in.dict(exclude_unset=True)
  43. for field in obj_data:
  44. if field in update_data:
  45. setattr(db_obj, field, update_data[field])
  46. db.add(db_obj)
  47. db.commit()
  48. db.refresh(db_obj)
  49. return db_obj
  50. def remove(self, db: Session, *, id: int) -> ModelType:
  51. obj = db.query(self.model).get(id)
  52. db.delete(obj)
  53. db.commit()
  54. return obj