This commit is contained in:
leo
2023-11-04 18:10:56 +08:00
commit da3b1a9f34
34 changed files with 1082 additions and 0 deletions

13
app/api/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
from fastapi import APIRouter
import app.api.form
import app.api.result
import app.api.bench
api_router = APIRouter()
api_router.include_router(bench.router, prefix='/benches', tags=['benches'])
api_router.include_router(form.router, prefix='/forms', tags=['forms'])
api_router.include_router(result.router, prefix='/results', tags=['results'])

44
app/api/bench.py Normal file
View File

@@ -0,0 +1,44 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
from fastapi import APIRouter, Body, Depends
from loguru import logger
from sqlalchemy.orm import Session
from app.db.base import get_db
from app.crud import form_crud, result_crud
from app.schemas import Bench
from app.schemas.form import Form
router = APIRouter()
@router.get("/forms", response_model=list[Form])
def read_bench_forms(db: Session = Depends(get_db), limit: int = 10):
return form_crud.form.get_last_bench(db, num=limit)
@router.post("/result")
def read_bench_forms(uuid: str = Body(), db: Session = Depends(get_db), ):
logger.debug(f"{uuid=}")
result_dbs = result_crud.result.get_by_uuid(db, uuid=uuid)
res = {
"百川": {"value": ""},
"ChatGPT": {"value": ""},
"MyTwins": {"value": ""},
}
for result in result_dbs:
res[result.name] = result
return res
@router.get("/", response_model=list[Bench])
def read_forms(db: Session = Depends(get_db), limit: int = 10):
res = []
form_dbs = form_crud.form.get_last_bench(db, num=limit)
for form_db in form_dbs:
res.append(
Bench(
form=form_db,
results=result_crud.result.get_by_uuid(db, uuid=form_db.uuid)
))
return res

43
app/api/form.py Normal file
View File

@@ -0,0 +1,43 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from app.db.base import get_db
from app.crud import form_crud
from app.schemas import Form, FormCreate, FormUpdate
router = APIRouter()
@router.put("/{item_id}", response_model=Form)
def update_form(item_id: int, form_in: FormUpdate, db: Session = Depends(get_db)):
form_db = form_crud.form.get(db, item_id)
if not form_db:
raise HTTPException(status_code=404, detail="Form not found")
return form_crud.form.update(db=db, db_obj=form_db, obj_in=form_in)
@router.delete("/{item_id}")
def delete_form(item_id: int, db: Session = Depends(get_db)):
form_db = form_crud.form.get(db, item_id)
if not form_db:
raise HTTPException(status_code=404, detail="Form not found")
return form_crud.form.remove(db=db, id=item_id)
@router.get("/{item_id}", response_model=Form)
def read_form(item_id: int, db: Session = Depends(get_db)):
return form_crud.form.get(db, item_id)
@router.post("/", response_model=Form)
def create_form(form_in: FormCreate, db: Session = Depends(get_db)):
form_db = form_crud.form.create(db=db, obj_in=form_in)
return form_db
@router.get("/", response_model=list[Form])
def read_forms(db: Session = Depends(get_db), skip: int = 0, limit: int = 100):
return form_crud.form.get_multi(db, skip=skip, limit=limit)

43
app/api/result.py Normal file
View File

@@ -0,0 +1,43 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from app.db.base import get_db
from app.crud import result_crud
from app.schemas import Result, ResultCreate, ResultUpdate
router = APIRouter()
@router.put("/{item_id}", response_model=Result)
def update_result(item_id: int, result_in: ResultUpdate, db: Session = Depends(get_db)):
result_db = result_crud.result.get(db, item_id)
if not result_db:
raise HTTPException(status_code=404, detail="Result not found")
return result_crud.result.update(db=db, db_obj=result_db, obj_in=result_in)
@router.delete("/{item_id}")
def delete_result(item_id: int, db: Session = Depends(get_db)):
result_db = result_crud.result.get(db, item_id)
if not result_db:
raise HTTPException(status_code=404, detail="Result not found")
return result_crud.result.remove(db=db, id=item_id)
@router.get("/{item_id}", response_model=Result)
def read_result(item_id: int, db: Session = Depends(get_db)):
return result_crud.result.get(db, item_id)
@router.post("/", response_model=Result)
def create_result(result_in: ResultCreate, db: Session = Depends(get_db)):
result_db = result_crud.result.create(db=db, obj_in=result_in)
return result_db
@router.get("/", response_model=list[Result])
def read_results(db: Session = Depends(get_db), skip: int = 0, limit: int = 100):
return result_crud.result.get_multi(db, skip=skip, limit=limit)

2
app/core/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

14
app/core/config.py Normal file
View File

@@ -0,0 +1,14 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import os
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
BASE_DIR: str = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
mysql_dsn: str = 'mysql+pymysql://root:123456@127.0.0.1:3306/model?charset=utf8mb4'
settings = Settings()

2
app/crud/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

82
app/crud/base.py Normal file
View File

@@ -0,0 +1,82 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
from loguru import logger
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select
from app.db.base import Base
ModelType = TypeVar("ModelType", bound=Base)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
def __init__(self, model: Type[ModelType]):
"""
CRUD object with default methods to Create, Read, Update, Delete (CRUD).
**Parameters**
* `model`: A SQLAlchemy model class
* `schema`: A Pydantic model (schema) class
"""
self.model = model
def get(self, db: Session, id: Any) -> Optional[ModelType]:
return db.query(self.model).filter(self.model.id == id).first()
def page_query(self) -> Select:
return select(self.model)
def get_multi(
self, db: Session, *, skip: int = 0, limit: int = 100
) -> List[ModelType]:
return db.query(self.model).offset(skip).limit(limit).all()
def get_multi_query(self) -> Select:
return select(self.model)
def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
obj_in_data = obj_in.model_dump()
logger.debug(f"{obj_in_data=}")
db_obj = self.model(**obj_in_data) # type: ignore
db.add(db_obj)
db.commit()
db.refresh(db_obj)
logger.debug(f"created {self.model.__name__}: {db_obj.id=}")
return db_obj
def update(
self,
db: Session,
*,
db_obj: ModelType,
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
) -> ModelType:
obj_data = obj_in.model_dump()
if isinstance(obj_in, dict):
update_data = obj_in
else:
update_data = obj_in.model_dump(exclude_unset=True)
for field in obj_data:
if field in update_data:
setattr(db_obj, field, update_data[field])
db.add(db_obj)
db.commit()
db.refresh(db_obj)
return db_obj
def remove(self, db: Session, *, id: int) -> ModelType:
obj = db.query(self.model).get(id)
db.delete(obj)
db.commit()
return obj
def get_by_ids(self, db: Session, ids: list[int]):
return db.query(self.model).filter(self.model.id.in_(ids)).all()

30
app/crud/form_crud.py Normal file
View File

@@ -0,0 +1,30 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
from sqlalchemy import and_, func
from sqlalchemy.orm import Session
from app.models.form import Form
from app.schemas import FormCreate, FormUpdate
from app.crud.base import CRUDBase
class CRUDForm(CRUDBase[Form, FormUpdate, FormCreate]):
def get_last_bench(self, db: Session, num: int = 10):
# 创建子查询获取每个不同的name的最大id
subquery = (
db.query(self.model.name, func.max(self.model.id).label('max_id'))
.group_by(self.model.name)
.subquery()
)
# 主查询与子查询连接获取每个name的最后一条数据
results = (
db.query(self.model)
.join(subquery, and_(self.model.id == subquery.c.max_id))
.all()
)
return results
form = CRUDForm(Form)

15
app/crud/result_crud.py Normal file
View File

@@ -0,0 +1,15 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
from sqlalchemy.orm import Session
from app.models.result import Result
from app.schemas import ResultCreate, ResultUpdate
from app.crud.base import CRUDBase
class CRUDResult(CRUDBase[Result, ResultUpdate, ResultCreate]):
def get_by_uuid(self, db: Session, uuid: str) -> Result:
return db.query(self.model).filter(self.model.uuid == uuid).all()
result = CRUDResult(Result)

2
app/db/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

38
app/db/base.py Normal file
View File

@@ -0,0 +1,38 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
from contextlib import contextmanager
from typing import Generator
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import declarative_base
from app.core.config import settings
Base = declarative_base()
SQLALCHEMY_DATABASE_URL = settings.mysql_dsn
engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, echo=False, pool_size=10, max_overflow=20)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# @contextmanager
# def get_db() -> Generator:
# db = SessionLocal()
# try:
# yield db
# finally:
# db.close()
def get_db() -> Generator:
try:
db = SessionLocal()
yield db
finally:
db.close()
if __name__ == '__main__':
with get_db() as db:
print(db.info)

2
app/models/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

26
app/models/form.py Normal file
View File

@@ -0,0 +1,26 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
from sqlalchemy import Column, Integer, String, Float, TIMESTAMP, text, func
from app.db.base import Base
class Form(Base):
__tablename__ = 'form'
id = Column(Integer, primary_key=True, nullable=False)
base_prompt = Column(String(length=1024))
prompt = Column(String(length=2048))
p_choice = Column(String(length=255))
role = Column(String(length=255))
name = Column(String(length=1024))
uuid = Column(String(length=1024))
desc = Column(String(length=2048))
price = Column(Float)
favorable = Column(String(length=1024))
remark = Column(String(length=1024))
otherPrompt = Column(String(length=1024))
lang = Column(String(length=1024))
type = Column(String(length=1024))
created_at = Column(TIMESTAMP, server_default=func.now())
updated_at = Column(TIMESTAMP, server_default=text('CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP'))

18
app/models/result.py Normal file
View File

@@ -0,0 +1,18 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
from sqlalchemy import Column, Integer, String, TIMESTAMP, func, text
from app.db.base import Base
class Result(Base):
__tablename__ = 'result'
id = Column(Integer, primary_key=True, nullable=False)
prompt = Column(String())
name = Column(String())
uuid = Column(String())
value = Column(String())
lang = Column(String())
created_at = Column(TIMESTAMP, server_default=func.now())
updated_at = Column(TIMESTAMP, server_default=text('CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP'))

5
app/schemas/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
from .form import Form, FormCreate, FormUpdate
from .result import Result, ResultCreate, ResultUpdate
from .bench import Bench

10
app/schemas/bench.py Normal file
View File

@@ -0,0 +1,10 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
from pydantic import BaseModel
from app.schemas import Form, Result
class Bench(BaseModel):
form: Form
results: list[Result]

44
app/schemas/form.py Normal file
View File

@@ -0,0 +1,44 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
from datetime import datetime
from pydantic import BaseModel
from pydantic.v1 import validator
class FormBase(BaseModel):
base_prompt: str | None = None
prompt: str | None = None
p_choice: str | None = None
role: str | None = None
name: str | None = None
uuid: str | None = None
desc: str | None = None
price: float | None = None
favorable: str | None = None
remark: str | None = None
otherPrompt: str | None = None
lang: str | None = None
type: str | None = None
class Form(FormBase):
id: int
created_at: datetime = datetime.now()
updated_at: datetime = datetime.now()
# custom input conversion for that field
_normalize_datetimes = validator(
"created_at", "updated_at",
allow_reuse=True)(lambda v: v.timestamp())
class Config:
from_attributes = True
class FormCreate(FormBase):
...
class FormUpdate(FormBase):
...

34
app/schemas/result.py Normal file
View File

@@ -0,0 +1,34 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
from datetime import datetime
from pydantic import BaseModel
from pydantic.v1 import validator
class ResultBase(BaseModel):
prompt: str | None = None
name: str | None = None
uuid: str | None = None
value: str | None = None
lang: str | None = None
class Result(ResultBase):
id: int
created_at: datetime = datetime.now()
updated_at: datetime = datetime.now()
_normalize_datetimes = validator(
"created_at", "updated_at",
allow_reuse=True)(lambda v: v.timestamp())
class Config:
from_attributes = True
class ResultCreate(ResultBase):
...
class ResultUpdate(ResultBase):
...