init
This commit is contained in:
13
app/api/__init__.py
Normal file
13
app/api/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
from fastapi import APIRouter
|
||||
|
||||
import app.api.form
|
||||
import app.api.result
|
||||
import app.api.bench
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
api_router.include_router(bench.router, prefix='/benches', tags=['benches'])
|
||||
api_router.include_router(form.router, prefix='/forms', tags=['forms'])
|
||||
api_router.include_router(result.router, prefix='/results', tags=['results'])
|
||||
44
app/api/bench.py
Normal file
44
app/api/bench.py
Normal file
@@ -0,0 +1,44 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
from fastapi import APIRouter, Body, Depends
|
||||
from loguru import logger
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db.base import get_db
|
||||
from app.crud import form_crud, result_crud
|
||||
from app.schemas import Bench
|
||||
from app.schemas.form import Form
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/forms", response_model=list[Form])
|
||||
def read_bench_forms(db: Session = Depends(get_db), limit: int = 10):
|
||||
return form_crud.form.get_last_bench(db, num=limit)
|
||||
|
||||
|
||||
@router.post("/result")
|
||||
def read_bench_forms(uuid: str = Body(), db: Session = Depends(get_db), ):
|
||||
logger.debug(f"{uuid=}")
|
||||
result_dbs = result_crud.result.get_by_uuid(db, uuid=uuid)
|
||||
res = {
|
||||
"百川": {"value": ""},
|
||||
"ChatGPT": {"value": ""},
|
||||
"MyTwins": {"value": ""},
|
||||
}
|
||||
for result in result_dbs:
|
||||
res[result.name] = result
|
||||
return res
|
||||
|
||||
|
||||
@router.get("/", response_model=list[Bench])
|
||||
def read_forms(db: Session = Depends(get_db), limit: int = 10):
|
||||
res = []
|
||||
form_dbs = form_crud.form.get_last_bench(db, num=limit)
|
||||
for form_db in form_dbs:
|
||||
res.append(
|
||||
Bench(
|
||||
form=form_db,
|
||||
results=result_crud.result.get_by_uuid(db, uuid=form_db.uuid)
|
||||
))
|
||||
return res
|
||||
43
app/api/form.py
Normal file
43
app/api/form.py
Normal file
@@ -0,0 +1,43 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db.base import get_db
|
||||
from app.crud import form_crud
|
||||
from app.schemas import Form, FormCreate, FormUpdate
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.put("/{item_id}", response_model=Form)
|
||||
def update_form(item_id: int, form_in: FormUpdate, db: Session = Depends(get_db)):
|
||||
form_db = form_crud.form.get(db, item_id)
|
||||
if not form_db:
|
||||
raise HTTPException(status_code=404, detail="Form not found")
|
||||
return form_crud.form.update(db=db, db_obj=form_db, obj_in=form_in)
|
||||
|
||||
|
||||
@router.delete("/{item_id}")
|
||||
def delete_form(item_id: int, db: Session = Depends(get_db)):
|
||||
form_db = form_crud.form.get(db, item_id)
|
||||
if not form_db:
|
||||
raise HTTPException(status_code=404, detail="Form not found")
|
||||
return form_crud.form.remove(db=db, id=item_id)
|
||||
|
||||
|
||||
@router.get("/{item_id}", response_model=Form)
|
||||
def read_form(item_id: int, db: Session = Depends(get_db)):
|
||||
return form_crud.form.get(db, item_id)
|
||||
|
||||
|
||||
@router.post("/", response_model=Form)
|
||||
def create_form(form_in: FormCreate, db: Session = Depends(get_db)):
|
||||
form_db = form_crud.form.create(db=db, obj_in=form_in)
|
||||
return form_db
|
||||
|
||||
|
||||
@router.get("/", response_model=list[Form])
|
||||
def read_forms(db: Session = Depends(get_db), skip: int = 0, limit: int = 100):
|
||||
return form_crud.form.get_multi(db, skip=skip, limit=limit)
|
||||
43
app/api/result.py
Normal file
43
app/api/result.py
Normal file
@@ -0,0 +1,43 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db.base import get_db
|
||||
from app.crud import result_crud
|
||||
from app.schemas import Result, ResultCreate, ResultUpdate
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.put("/{item_id}", response_model=Result)
|
||||
def update_result(item_id: int, result_in: ResultUpdate, db: Session = Depends(get_db)):
|
||||
result_db = result_crud.result.get(db, item_id)
|
||||
if not result_db:
|
||||
raise HTTPException(status_code=404, detail="Result not found")
|
||||
return result_crud.result.update(db=db, db_obj=result_db, obj_in=result_in)
|
||||
|
||||
|
||||
@router.delete("/{item_id}")
|
||||
def delete_result(item_id: int, db: Session = Depends(get_db)):
|
||||
result_db = result_crud.result.get(db, item_id)
|
||||
if not result_db:
|
||||
raise HTTPException(status_code=404, detail="Result not found")
|
||||
return result_crud.result.remove(db=db, id=item_id)
|
||||
|
||||
|
||||
@router.get("/{item_id}", response_model=Result)
|
||||
def read_result(item_id: int, db: Session = Depends(get_db)):
|
||||
return result_crud.result.get(db, item_id)
|
||||
|
||||
|
||||
@router.post("/", response_model=Result)
|
||||
def create_result(result_in: ResultCreate, db: Session = Depends(get_db)):
|
||||
result_db = result_crud.result.create(db=db, obj_in=result_in)
|
||||
return result_db
|
||||
|
||||
|
||||
@router.get("/", response_model=list[Result])
|
||||
def read_results(db: Session = Depends(get_db), skip: int = 0, limit: int = 100):
|
||||
return result_crud.result.get_multi(db, skip=skip, limit=limit)
|
||||
2
app/core/__init__.py
Normal file
2
app/core/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
14
app/core/config.py
Normal file
14
app/core/config.py
Normal file
@@ -0,0 +1,14 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
import os
|
||||
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
BASE_DIR: str = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
mysql_dsn: str = 'mysql+pymysql://root:123456@127.0.0.1:3306/model?charset=utf8mb4'
|
||||
|
||||
|
||||
settings = Settings()
|
||||
2
app/crud/__init__.py
Normal file
2
app/crud/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
82
app/crud/base.py
Normal file
82
app/crud/base.py
Normal file
@@ -0,0 +1,82 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import Select
|
||||
|
||||
from app.db.base import Base
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=Base)
|
||||
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
|
||||
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
|
||||
|
||||
|
||||
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
def __init__(self, model: Type[ModelType]):
|
||||
"""
|
||||
CRUD object with default methods to Create, Read, Update, Delete (CRUD).
|
||||
|
||||
**Parameters**
|
||||
|
||||
* `model`: A SQLAlchemy model class
|
||||
* `schema`: A Pydantic model (schema) class
|
||||
"""
|
||||
self.model = model
|
||||
|
||||
def get(self, db: Session, id: Any) -> Optional[ModelType]:
|
||||
return db.query(self.model).filter(self.model.id == id).first()
|
||||
|
||||
def page_query(self) -> Select:
|
||||
return select(self.model)
|
||||
|
||||
def get_multi(
|
||||
self, db: Session, *, skip: int = 0, limit: int = 100
|
||||
) -> List[ModelType]:
|
||||
return db.query(self.model).offset(skip).limit(limit).all()
|
||||
|
||||
def get_multi_query(self) -> Select:
|
||||
return select(self.model)
|
||||
|
||||
def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
|
||||
obj_in_data = obj_in.model_dump()
|
||||
logger.debug(f"{obj_in_data=}")
|
||||
db_obj = self.model(**obj_in_data) # type: ignore
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
logger.debug(f"created {self.model.__name__}: {db_obj.id=}")
|
||||
return db_obj
|
||||
|
||||
def update(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
db_obj: ModelType,
|
||||
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
|
||||
) -> ModelType:
|
||||
obj_data = obj_in.model_dump()
|
||||
if isinstance(obj_in, dict):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
for field in obj_data:
|
||||
if field in update_data:
|
||||
setattr(db_obj, field, update_data[field])
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
def remove(self, db: Session, *, id: int) -> ModelType:
|
||||
obj = db.query(self.model).get(id)
|
||||
db.delete(obj)
|
||||
db.commit()
|
||||
return obj
|
||||
|
||||
def get_by_ids(self, db: Session, ids: list[int]):
|
||||
return db.query(self.model).filter(self.model.id.in_(ids)).all()
|
||||
30
app/crud/form_crud.py
Normal file
30
app/crud/form_crud.py
Normal file
@@ -0,0 +1,30 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
from sqlalchemy import and_, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.form import Form
|
||||
from app.schemas import FormCreate, FormUpdate
|
||||
from app.crud.base import CRUDBase
|
||||
|
||||
|
||||
class CRUDForm(CRUDBase[Form, FormUpdate, FormCreate]):
|
||||
def get_last_bench(self, db: Session, num: int = 10):
|
||||
# 创建子查询,获取每个不同的name的最大id
|
||||
subquery = (
|
||||
db.query(self.model.name, func.max(self.model.id).label('max_id'))
|
||||
.group_by(self.model.name)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
# 主查询,与子查询连接,获取每个name的最后一条数据
|
||||
results = (
|
||||
db.query(self.model)
|
||||
.join(subquery, and_(self.model.id == subquery.c.max_id))
|
||||
.all()
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
form = CRUDForm(Form)
|
||||
15
app/crud/result_crud.py
Normal file
15
app/crud/result_crud.py
Normal file
@@ -0,0 +1,15 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.result import Result
|
||||
from app.schemas import ResultCreate, ResultUpdate
|
||||
from app.crud.base import CRUDBase
|
||||
|
||||
|
||||
class CRUDResult(CRUDBase[Result, ResultUpdate, ResultCreate]):
|
||||
def get_by_uuid(self, db: Session, uuid: str) -> Result:
|
||||
return db.query(self.model).filter(self.model.uuid == uuid).all()
|
||||
|
||||
|
||||
result = CRUDResult(Result)
|
||||
2
app/db/__init__.py
Normal file
2
app/db/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
38
app/db/base.py
Normal file
38
app/db/base.py
Normal file
@@ -0,0 +1,38 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import declarative_base
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
Base = declarative_base()
|
||||
SQLALCHEMY_DATABASE_URL = settings.mysql_dsn
|
||||
|
||||
engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, echo=False, pool_size=10, max_overflow=20)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
# @contextmanager
|
||||
# def get_db() -> Generator:
|
||||
# db = SessionLocal()
|
||||
# try:
|
||||
# yield db
|
||||
# finally:
|
||||
# db.close()
|
||||
|
||||
|
||||
def get_db() -> Generator:
|
||||
try:
|
||||
db = SessionLocal()
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
with get_db() as db:
|
||||
print(db.info)
|
||||
2
app/models/__init__.py
Normal file
2
app/models/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
26
app/models/form.py
Normal file
26
app/models/form.py
Normal file
@@ -0,0 +1,26 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
from sqlalchemy import Column, Integer, String, Float, TIMESTAMP, text, func
|
||||
from app.db.base import Base
|
||||
|
||||
|
||||
class Form(Base):
|
||||
__tablename__ = 'form'
|
||||
|
||||
id = Column(Integer, primary_key=True, nullable=False)
|
||||
base_prompt = Column(String(length=1024))
|
||||
prompt = Column(String(length=2048))
|
||||
p_choice = Column(String(length=255))
|
||||
role = Column(String(length=255))
|
||||
name = Column(String(length=1024))
|
||||
uuid = Column(String(length=1024))
|
||||
desc = Column(String(length=2048))
|
||||
price = Column(Float)
|
||||
favorable = Column(String(length=1024))
|
||||
remark = Column(String(length=1024))
|
||||
otherPrompt = Column(String(length=1024))
|
||||
lang = Column(String(length=1024))
|
||||
type = Column(String(length=1024))
|
||||
created_at = Column(TIMESTAMP, server_default=func.now())
|
||||
updated_at = Column(TIMESTAMP, server_default=text('CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP'))
|
||||
18
app/models/result.py
Normal file
18
app/models/result.py
Normal file
@@ -0,0 +1,18 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
from sqlalchemy import Column, Integer, String, TIMESTAMP, func, text
|
||||
|
||||
from app.db.base import Base
|
||||
|
||||
|
||||
class Result(Base):
|
||||
__tablename__ = 'result'
|
||||
|
||||
id = Column(Integer, primary_key=True, nullable=False)
|
||||
prompt = Column(String())
|
||||
name = Column(String())
|
||||
uuid = Column(String())
|
||||
value = Column(String())
|
||||
lang = Column(String())
|
||||
created_at = Column(TIMESTAMP, server_default=func.now())
|
||||
updated_at = Column(TIMESTAMP, server_default=text('CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP'))
|
||||
5
app/schemas/__init__.py
Normal file
5
app/schemas/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
from .form import Form, FormCreate, FormUpdate
|
||||
from .result import Result, ResultCreate, ResultUpdate
|
||||
from .bench import Bench
|
||||
10
app/schemas/bench.py
Normal file
10
app/schemas/bench.py
Normal file
@@ -0,0 +1,10 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.schemas import Form, Result
|
||||
|
||||
|
||||
class Bench(BaseModel):
|
||||
form: Form
|
||||
results: list[Result]
|
||||
44
app/schemas/form.py
Normal file
44
app/schemas/form.py
Normal file
@@ -0,0 +1,44 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic.v1 import validator
|
||||
|
||||
|
||||
class FormBase(BaseModel):
|
||||
base_prompt: str | None = None
|
||||
prompt: str | None = None
|
||||
p_choice: str | None = None
|
||||
role: str | None = None
|
||||
name: str | None = None
|
||||
uuid: str | None = None
|
||||
desc: str | None = None
|
||||
price: float | None = None
|
||||
favorable: str | None = None
|
||||
remark: str | None = None
|
||||
otherPrompt: str | None = None
|
||||
lang: str | None = None
|
||||
type: str | None = None
|
||||
|
||||
|
||||
class Form(FormBase):
|
||||
id: int
|
||||
created_at: datetime = datetime.now()
|
||||
updated_at: datetime = datetime.now()
|
||||
# custom input conversion for that field
|
||||
_normalize_datetimes = validator(
|
||||
"created_at", "updated_at",
|
||||
allow_reuse=True)(lambda v: v.timestamp())
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class FormCreate(FormBase):
|
||||
...
|
||||
|
||||
|
||||
class FormUpdate(FormBase):
|
||||
...
|
||||
34
app/schemas/result.py
Normal file
34
app/schemas/result.py
Normal file
@@ -0,0 +1,34 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic.v1 import validator
|
||||
|
||||
|
||||
class ResultBase(BaseModel):
|
||||
prompt: str | None = None
|
||||
name: str | None = None
|
||||
uuid: str | None = None
|
||||
value: str | None = None
|
||||
lang: str | None = None
|
||||
|
||||
|
||||
class Result(ResultBase):
|
||||
id: int
|
||||
created_at: datetime = datetime.now()
|
||||
updated_at: datetime = datetime.now()
|
||||
_normalize_datetimes = validator(
|
||||
"created_at", "updated_at",
|
||||
allow_reuse=True)(lambda v: v.timestamp())
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ResultCreate(ResultBase):
|
||||
...
|
||||
|
||||
|
||||
class ResultUpdate(ResultBase):
|
||||
...
|
||||
Reference in New Issue
Block a user