#!/usr/bin/env python3
# -*- coding:utf-8 -*-
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union

from loguru import logger
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select

from app.db.base import Base

ModelType = TypeVar("ModelType", bound=Base)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)


class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
    """Generic Create/Read/Update/Delete helpers for one SQLAlchemy model.

    Parameterize with concrete types, e.g.
    ``CRUDBase[User, UserCreate, UserUpdate]``, and pass the model class to
    ``__init__``. Every method receives the ``Session`` to use explicitly.
    The model is expected to expose an ``id`` column (used by ``get``,
    ``get_by_ids`` and the create log line).
    """

    def __init__(self, model: Type[ModelType]):
        """
        CRUD object with default methods to Create, Read, Update, Delete (CRUD).

        **Parameters**

        * `model`: A SQLAlchemy model class
        """
        self.model = model

    def get(self, db: Session, id: Any) -> Optional[ModelType]:
        """Return the first row whose ``id`` equals ``id``, or ``None``."""
        return db.query(self.model).filter(self.model.id == id).first()

    def page_query(self) -> Select:
        """Return an unexecuted ``SELECT`` over all rows of the model."""
        return select(self.model)

    def get_multi(
        self, db: Session, *, skip: int = 0, limit: int = 100
    ) -> List[ModelType]:
        """Return up to ``limit`` rows after skipping the first ``skip``.

        No ORDER BY is applied, so row order is whatever the database returns.
        """
        return db.query(self.model).offset(skip).limit(limit).all()

    def get_multi_query(self) -> Select:
        """Return an unexecuted ``SELECT`` over all rows of the model."""
        return select(self.model)

    def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
        """Insert a new row built from ``obj_in``, commit, and return it refreshed.

        NOTE(review): the full input payload is written to the debug log;
        confirm that create schemas never carry secrets.
        """
        obj_in_data = obj_in.model_dump()
        logger.debug(f"{obj_in_data=}")
        db_obj = self.model(**obj_in_data)  # type: ignore
        db.add(db_obj)
        db.commit()
        # Reload server-generated values (e.g. the primary key) after commit.
        db.refresh(db_obj)
        logger.debug(f"created {self.model.__name__}: {db_obj.id=}")
        return db_obj

    def update(
        self,
        db: Session,
        *,
        db_obj: ModelType,
        obj_in: Union[UpdateSchemaType, Dict[str, Any]],
    ) -> ModelType:
        """Apply ``obj_in`` to ``db_obj``, commit, and return it refreshed.

        * Pydantic input: only fields that were explicitly set
          (``exclude_unset=True``) are applied, so omitted fields keep their
          current values.
        * Dict input: every key is assigned as an attribute of ``db_obj``.
        """
        if isinstance(obj_in, dict):
            update_data = obj_in
        else:
            update_data = obj_in.model_dump(exclude_unset=True)
        # Iterate the payload itself rather than calling ``model_dump()`` on
        # ``obj_in`` unconditionally, which would fail for plain dicts.
        for field, value in update_data.items():
            setattr(db_obj, field, value)
        db.add(db_obj)
        db.commit()
        db.refresh(db_obj)
        return db_obj

    def remove(self, db: Session, *, id: int) -> ModelType:
        """Delete the row with primary key ``id``, commit, and return it.

        ``Session.get`` returns ``None`` when no such row exists, in which
        case ``db.delete`` raises SQLAlchemy's ``UnmappedInstanceError``.
        """
        # Session.get replaces the legacy Query.get (deprecated in 2.0) with
        # the same primary-key / identity-map lookup semantics.
        obj = db.get(self.model, id)
        db.delete(obj)
        db.commit()
        return obj

    def get_by_ids(self, db: Session, ids: list[int]) -> List[ModelType]:
        """Return all rows whose ``id`` is in ``ids`` (order not guaranteed)."""
        if not ids:
            # Nothing can match; skip the round-trip for an empty IN clause.
            return []
        return db.query(self.model).filter(self.model.id.in_(ids)).all()