From bc168d08a9c359e1d2132ba0253cb2a002165b11 Mon Sep 17 00:00:00 2001 From: Wenjie Zhang Date: Fri, 2 May 2025 23:56:59 +0800 Subject: [PATCH 01/10] =?UTF-8?q?=E5=9F=BA=E6=9C=AC=E6=9D=83=E9=99=90?= =?UTF-8?q?=E6=8E=A7=E5=88=B6=E5=8A=9F=E8=83=BD=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 2 + requirements.txt | 3 +- server/db_manager.py | 10 + server/main.py | 102 +++++- server/models/user_model.py | 56 +++ server/routers/__init__.py | 4 +- server/routers/admin_router.py | 55 ++- server/routers/auth_router.py | 347 +++++++++++++++++++ server/routers/chat_router.py | 98 ++++-- server/routers/data_router.py | 36 +- server/utils/__init__.py | 1 + server/utils/auth_middleware.py | 120 +++++++ server/utils/auth_utils.py | 75 ++++ src/config/__init__.py | 2 + web/src/apis/admin_api.js | 314 +++++++++++++++++ web/src/apis/auth_api.js | 106 ++++++ web/src/apis/base.js | 142 ++++++++ web/src/apis/index.js | 33 ++ web/src/apis/public_api.js | 59 ++++ web/src/components/AgentChatComponent.vue | 29 +- web/src/components/ChatComponent.vue | 44 ++- web/src/components/TokenManagerComponent.vue | 54 +-- web/src/components/UserInfoComponent.vue | 139 ++++++++ web/src/layouts/AppLayout.vue | 13 +- web/src/router/index.js | 140 +++++--- web/src/stores/user.js | 232 +++++++++++++ web/src/views/AgentSingleView.vue | 210 +++++------ web/src/views/AgentView.vue | 159 ++++++--- web/src/views/DataBaseInfoView.vue | 81 ++--- web/src/views/DataBaseView.vue | 43 +-- web/src/views/GraphView.vue | 130 +++---- web/src/views/HomeView.vue | 44 ++- web/src/views/LoginView.vue | 303 ++++++++++++++++ web/src/views/SettingView.vue | 339 +++++++++++++++++- 34 files changed, 2994 insertions(+), 531 deletions(-) create mode 100644 server/models/user_model.py create mode 100644 server/routers/auth_router.py create mode 100644 server/utils/__init__.py create mode 100644 server/utils/auth_middleware.py create mode 100644 server/utils/auth_utils.py create mode 100644 web/src/apis/admin_api.js create mode 100644 web/src/apis/auth_api.js create mode 100644 web/src/apis/base.js create mode 100644 web/src/apis/index.js create mode 100644 web/src/apis/public_api.js create mode 100644 web/src/components/UserInfoComponent.vue create mode 100644 web/src/stores/user.js create mode 100644 web/src/views/LoginView.vue diff --git a/pyproject.toml b/pyproject.toml index 64cc709b1..7651724ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,9 +19,11 @@ dependencies = [ "openai>=1.76.0", "opencv-python-headless>=4.11.0.86", "paddleocr>=2.10.0", + "pyjwt>=2.8.0", "pymilvus>=2.5.8", "pymupdf>=1.25.5", "python-dotenv>=1.1.0", + "python-jose[cryptography]>=3.4.0", "python-multipart>=0.0.20", "pyyaml>=6.0.2", "qianfan>=0.4.12.3", diff --git a/requirements.txt b/requirements.txt index 790c4f8ab..e0a7a1374 100644 --- a/requirements.txt +++ b/requirements.txt @@ -31,4 +31,5 @@ langchain langsmith langgraph langchain-openai -langchain-community \ No newline at end of file +langchain-community +PyJWT>=2.10.1 \ No newline at end of file diff --git a/server/db_manager.py b/server/db_manager.py index 0a05bcb91..6b526f2c0 100644 --- a/server/db_manager.py +++ b/server/db_manager.py @@ -6,6 +6,7 @@ from sqlalchemy.ext.declarative import declarative_base from server.models.token_model import Base, AgentToken +from server.models.user_model import User, OperationLog class DBManager: """数据库管理器""" @@ -32,6 +33,15 @@ def create_tables(self): """创建数据库表""" Base.metadata.create_all(self.engine) + def check_first_run(self): + """检查是否首次运行""" + session = self.get_session() + try: + # 检查是否有任何用户存在 + return session.query(User).count() == 0 + finally: + session.close() + def get_session(self): """获取数据库会话""" return self.Session() diff --git a/server/main.py b/server/main.py index 93277ccf1..5f383fc6e 100644 --- a/server/main.py +++ b/server/main.py @@ -1,8 +1,13 @@ import uvicorn -from fastapi import FastAPI +from fastapi import FastAPI, Request, HTTPException, status, Depends from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from starlette.middleware.base import BaseHTTPMiddleware + from server.routers import router +from server.utils.auth_middleware import get_current_user, is_public_path, is_admin_path +from server.models.user_model import User from src.utils.logging_config import logger @@ -18,6 +23,101 @@ allow_headers=["*"], ) +# 鉴权中间件 +class AuthMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + # 获取请求路径 + path = request.url.path + + # 检查是否为公开路径,公开路径无需身份验证 + if is_public_path(path): + return await call_next(request) + + # 注意:前端代理已经去掉了 /api 前缀,例如 /api/chat 变成了 /chat + # 判断是否需要验证的API请求,包括聊天、数据、工具等 + is_api_path = ( + path.startswith("/chat") or + path.startswith("/data") or + path.startswith("/admin") or + path.startswith("/auth") and not is_public_path(path) + ) + + if not is_api_path: + # 非API路径,可能是前端路由或静态资源 + return await call_next(request) + + # 提取Authorization头 + auth_header = request.headers.get("Authorization") + if not auth_header or not auth_header.startswith("Bearer "): + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={"detail": "请先登录"}, + headers={"WWW-Authenticate": "Bearer"} + ) + + # 获取token + token = auth_header.split("Bearer ")[1] + + # 添加token到请求状态,后续路由可以直接使用 + request.state.token = token + + # 检查是否需要管理员权限 + if is_admin_path(path): + # 尝试获取数据库会话 + try: + from server.db_manager import db_manager + from server.utils.auth_utils import AuthUtils + + db = db_manager.get_session() + try: + # 验证token并获取用户信息 + payload = AuthUtils.verify_access_token(token) + user_id = payload.get("sub") + + if not user_id: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={"detail": "无效的用户标识"}, + headers={"WWW-Authenticate": "Bearer"} + ) + + # 查询用户信息 + from server.models.user_model import User + user = db.query(User).filter(User.id == user_id).first() + + if not user: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={"detail": "用户不存在"}, + headers={"WWW-Authenticate": "Bearer"} + ) + + # 检查管理员权限 + if user.role not in ["admin", "superadmin"]: + return JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content={"detail": "需要管理员权限"} + ) + + # 将用户信息添加到请求状态 + request.state.user = user + + finally: + db.close() + + except Exception as e: + logger.error(f"验证管理员权限出错: {e}") + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={"detail": "验证用户权限出错"}, + headers={"WWW-Authenticate": "Bearer"} + ) + + # 继续处理请求 + return await call_next(request) + +# 添加鉴权中间件 +app.add_middleware(AuthMiddleware) if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=5050, threads=10, workers=10, reload=True) diff --git a/server/models/user_model.py b/server/models/user_model.py new file mode 100644 index 000000000..fa0f4dfbe --- /dev/null +++ b/server/models/user_model.py @@ -0,0 +1,56 @@ +from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, Text +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.sql import func +from sqlalchemy.orm import relationship + +from server.models.token_model import Base + +class User(Base): + """用户模型""" + __tablename__ = 'users' + + id = Column(Integer, primary_key=True, autoincrement=True) + username = Column(String, nullable=False, unique=True, index=True) + password_hash = Column(String, nullable=False) + role = Column(String, nullable=False, default='user') # 角色: superadmin, admin, user + created_at = Column(DateTime, default=func.now()) + last_login = Column(DateTime, nullable=True) + + # 关联操作日志 + operation_logs = relationship("OperationLog", back_populates="user") + + def to_dict(self, include_password=False): + result = { + "id": self.id, + "username": self.username, + "role": self.role, + "created_at": self.created_at.isoformat() if self.created_at else None, + "last_login": self.last_login.isoformat() if self.last_login else None + } + if include_password: + result["password_hash"] = self.password_hash + return result + +class OperationLog(Base): + """操作日志模型""" + __tablename__ = 'operation_logs' + + id = Column(Integer, primary_key=True, autoincrement=True) + user_id = Column(Integer, ForeignKey('users.id'), nullable=False) + operation = Column(String, nullable=False) + details = Column(Text, nullable=True) + ip_address = Column(String, nullable=True) + timestamp = Column(DateTime, default=func.now()) + + # 关联用户 + user = relationship("User", back_populates="operation_logs") + + def to_dict(self): + return { + "id": self.id, + "user_id": self.user_id, + "operation": self.operation, + "details": self.details, + "ip_address": self.ip_address, + "timestamp": self.timestamp.isoformat() if self.timestamp else None + } \ No newline at end of file diff --git a/server/routers/__init__.py b/server/routers/__init__.py index 6878d9347..b11f92745 100644 --- a/server/routers/__init__.py +++ b/server/routers/__init__.py @@ -2,12 +2,12 @@ from server.routers.chat_router import chat from server.routers.data_router import data from server.routers.base_router import base -from server.routers.tool_router import tool from server.routers.admin_router import admin +from server.routers.auth_router import auth router = APIRouter() router.include_router(base) router.include_router(chat) router.include_router(data) -router.include_router(tool) router.include_router(admin) +router.include_router(auth) diff --git a/server/routers/admin_router.py b/server/routers/admin_router.py index 9471b6bc3..1556e00eb 100644 --- a/server/routers/admin_router.py +++ b/server/routers/admin_router.py @@ -1,23 +1,18 @@ import secrets import string -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, HTTPException, Query, Request from pydantic import BaseModel from typing import List, Optional from sqlalchemy.orm import Session from server.db_manager import db_manager from server.models.token_model import AgentToken +from server.models.user_model import User, OperationLog +from server.utils.auth_middleware import get_db, get_current_user, get_admin_user, oauth2_scheme +from server.routers.auth_router import log_operation admin = APIRouter(prefix="/admin", tags=["admin"]) -# 依赖项:获取数据库会话 -def get_db(): - db = db_manager.get_session() - try: - yield db - finally: - db.close() - # 请求和响应模型 class TokenCreate(BaseModel): agent_id: str @@ -42,9 +37,10 @@ def generate_token(length=32): @admin.get("/tokens", response_model=List[TokenResponse]) async def get_agent_tokens( agent_id: Optional[str] = Query(None), + current_user: User = Depends(get_admin_user), db: Session = Depends(get_db) ): - """获取智能体的token列表""" + """获取智能体的token列表(需要管理员权限)""" query = db.query(AgentToken) if agent_id: query = query.filter(AgentToken.agent_id == agent_id) @@ -54,9 +50,11 @@ async def get_agent_tokens( @admin.post("/tokens", response_model=TokenResponse) async def create_token( token_data: TokenCreate, + request: Request, + current_user: User = Depends(get_admin_user), db: Session = Depends(get_db) ): - """创建新的token""" + """创建新的token(需要管理员权限)""" # 生成随机token token_value = generate_token() @@ -72,15 +70,38 @@ async def create_token( db.commit() db.refresh(new_token) + # 记录操作 + log_operation( + db, + current_user.id, + "创建令牌", + f"为智能体 {token_data.agent_id} 创建访问令牌: {token_data.name}", + request + ) + return new_token.to_dict() @admin.delete("/tokens/{token_id}", response_model=dict) -async def delete_token(token_id: int, db: Session = Depends(get_db)): - """删除token""" +async def delete_token( + token_id: int, + request: Request, + current_user: User = Depends(get_admin_user), + db: Session = Depends(get_db) +): + """删除token(需要管理员权限)""" token = db.query(AgentToken).filter(AgentToken.id == token_id).first() if not token: raise HTTPException(status_code=404, detail="Token not found") + # 记录操作信息 + log_operation( + db, + current_user.id, + "删除令牌", + f"删除令牌ID: {token_id}, 智能体: {token.agent_id}, 名称: {token.name}", + request + ) + db.delete(token) db.commit() @@ -89,15 +110,17 @@ async def delete_token(token_id: int, db: Session = Depends(get_db)): @admin.post("/verify_token") async def verify_agent_token( token_data: TokenVerify, + token: Optional[str] = Depends(oauth2_scheme), db: Session = Depends(get_db) ): - """验证智能体访问令牌""" - token = db.query(AgentToken).filter( + """验证智能体访问令牌(所有用户都可访问)""" + # 查找令牌 + agent_token = db.query(AgentToken).filter( AgentToken.agent_id == token_data.agent_id, AgentToken.token == token_data.token ).first() - if not token: + if not agent_token: raise HTTPException(status_code=401, detail="Invalid token") return {"success": True, "message": "Token verified"} \ No newline at end of file diff --git a/server/routers/auth_router.py b/server/routers/auth_router.py new file mode 100644 index 000000000..08654bf81 --- /dev/null +++ b/server/routers/auth_router.py @@ -0,0 +1,347 @@ +from fastapi import APIRouter, Depends, HTTPException, Request, status +from fastapi.security import OAuth2PasswordRequestForm +from pydantic import BaseModel +from typing import List, Optional +from sqlalchemy.orm import Session +from datetime import datetime, timedelta + +from server.db_manager import db_manager +from server.models.user_model import User, OperationLog +from server.utils.auth_utils import AuthUtils +from server.utils.auth_middleware import get_db, get_current_user, get_admin_user, get_superadmin_user, oauth2_scheme + +# 创建路由器 +auth = APIRouter(prefix="/auth", tags=["auth"]) + +# 请求和响应模型 +class Token(BaseModel): + access_token: str + token_type: str + user_id: int + username: str + role: str + +class UserCreate(BaseModel): + username: str + password: str + role: str = "user" + +class UserUpdate(BaseModel): + username: Optional[str] = None + password: Optional[str] = None + role: Optional[str] = None + +class UserResponse(BaseModel): + id: int + username: str + role: str + created_at: str + last_login: Optional[str] = None + +class InitializeAdmin(BaseModel): + username: str + password: str + +# 记录操作日志 +def log_operation(db: Session, user_id: int, operation: str, details: str = None, request: Request = None): + ip_address = None + if request: + ip_address = request.client.host if request.client else None + + log = OperationLog( + user_id=user_id, + operation=operation, + details=details, + ip_address=ip_address + ) + db.add(log) + db.commit() + +# 路由:登录获取令牌 +@auth.post("/token", response_model=Token) +async def login_for_access_token( + form_data: OAuth2PasswordRequestForm = Depends(), + db: Session = Depends(get_db) +): + # 查找用户 + user = db.query(User).filter(User.username == form_data.username).first() + + # 验证用户存在且密码正确 + if not user or not AuthUtils.verify_password(user.password_hash, form_data.password): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="用户名或密码错误", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # 更新最后登录时间 + user.last_login = datetime.now() + db.commit() + + # 生成访问令牌 + token_data = {"sub": str(user.id)} + access_token = AuthUtils.create_access_token(token_data) + + # 记录登录操作 + log_operation(db, user.id, "登录") + + return { + "access_token": access_token, + "token_type": "bearer", + "user_id": user.id, + "username": user.username, + "role": user.role + } + +# 路由:校验是否需要初始化管理员 +@auth.get("/check-first-run") +async def check_first_run(): + is_first_run = db_manager.check_first_run() + return {"first_run": is_first_run} + +# 路由:初始化管理员账户 +@auth.post("/initialize", response_model=Token) +async def initialize_admin( + admin_data: InitializeAdmin, + db: Session = Depends(get_db) +): + # 检查是否是首次运行 + if not db_manager.check_first_run(): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="系统已经初始化,无法再次创建初始管理员", + ) + + # 创建管理员账户 + hashed_password = AuthUtils.hash_password(admin_data.password) + + new_admin = User( + username=admin_data.username, + password_hash=hashed_password, + role="superadmin", + last_login=datetime.now() + ) + + db.add(new_admin) + db.commit() + db.refresh(new_admin) + + # 生成访问令牌 + token_data = {"sub": str(new_admin.id)} + access_token = AuthUtils.create_access_token(token_data) + + # 记录操作 + log_operation(db, new_admin.id, "系统初始化", "创建超级管理员账户") + + return { + "access_token": access_token, + "token_type": "bearer", + "user_id": new_admin.id, + "username": new_admin.username, + "role": new_admin.role + } + +# 路由:获取当前用户信息 +@auth.get("/me", response_model=UserResponse) +async def read_users_me(current_user: User = Depends(get_current_user)): + return current_user.to_dict() + +# 路由:创建新用户(管理员权限) +@auth.post("/users", response_model=UserResponse) +async def create_user( + user_data: UserCreate, + request: Request, + current_user: User = Depends(get_admin_user), + db: Session = Depends(get_db) +): + # 检查用户名是否已存在 + existing_user = db.query(User).filter(User.username == user_data.username).first() + if existing_user: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="用户名已存在", + ) + + # 创建新用户 + hashed_password = AuthUtils.hash_password(user_data.password) + + # 检查角色权限 + # 超级管理员可以创建任何类型的用户 + if user_data.role == "superadmin" and current_user.role != "superadmin": + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="只有超级管理员才能创建超级管理员账户", + ) + + # 管理员只能创建普通用户 + if current_user.role == "admin" and user_data.role != "user": + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="管理员只能创建普通用户账户", + ) + + new_user = User( + username=user_data.username, + password_hash=hashed_password, + role=user_data.role + ) + + db.add(new_user) + db.commit() + db.refresh(new_user) + + # 记录操作 + log_operation( + db, + current_user.id, + "创建用户", + f"创建用户: {user_data.username}, 角色: {user_data.role}", + request + ) + + return new_user.to_dict() + +# 路由:获取所有用户(管理员权限) +@auth.get("/users", response_model=List[UserResponse]) +async def read_users( + skip: int = 0, + limit: int = 100, + current_user: User = Depends(get_admin_user), + db: Session = Depends(get_db) +): + users = db.query(User).offset(skip).limit(limit).all() + return [user.to_dict() for user in users] + +# 路由:获取特定用户信息(管理员权限) +@auth.get("/users/{user_id}", response_model=UserResponse) +async def read_user( + user_id: int, + current_user: User = Depends(get_admin_user), + db: Session = Depends(get_db) +): + user = db.query(User).filter(User.id == user_id).first() + if user is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="用户不存在", + ) + return user.to_dict() + +# 路由:更新用户信息(管理员权限) +@auth.put("/users/{user_id}", response_model=UserResponse) +async def update_user( + user_id: int, + user_data: UserUpdate, + request: Request, + current_user: User = Depends(get_admin_user), + db: Session = Depends(get_db) +): + user = db.query(User).filter(User.id == user_id).first() + if user is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="用户不存在", + ) + + # 检查权限 + if user.role == "superadmin" and current_user.role != "superadmin": + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="只有超级管理员才能修改超级管理员账户", + ) + + # 超级管理员账户不能被降级(只能由其他超级管理员修改) + if user.role == "superadmin" and user_data.role and user_data.role != "superadmin" and current_user.id != user.id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="不能降级超级管理员账户", + ) + + # 更新信息 + update_details = [] + + if user_data.username is not None: + # 检查用户名是否已被其他用户使用 + existing_user = db.query(User).filter(User.username == user_data.username, User.id != user_id).first() + if existing_user: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="用户名已存在", + ) + user.username = user_data.username + update_details.append(f"用户名: {user_data.username}") + + if user_data.password is not None: + user.password_hash = AuthUtils.hash_password(user_data.password) + update_details.append("密码已更新") + + if user_data.role is not None: + user.role = user_data.role + update_details.append(f"角色: {user_data.role}") + + db.commit() + + # 记录操作 + log_operation( + db, + current_user.id, + "更新用户", + f"更新用户ID {user_id}: {', '.join(update_details)}", + request + ) + + return user.to_dict() + +# 路由:删除用户(管理员权限) +@auth.delete("/users/{user_id}", response_model=dict) +async def delete_user( + user_id: int, + request: Request, + current_user: User = Depends(get_admin_user), + db: Session = Depends(get_db) +): + user = db.query(User).filter(User.id == user_id).first() + if user is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="用户不存在", + ) + + # 检查权限 + if user.role == "superadmin": + # 只有超级管理员可以删除超级管理员 + if current_user.role != "superadmin": + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="只有超级管理员才能删除超级管理员账户", + ) + + # 检查是否是最后一个超级管理员 + superadmin_count = db.query(User).filter(User.role == "superadmin").count() + if superadmin_count <= 1: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="不能删除最后一个超级管理员账户", + ) + + # 不能删除自己的账户 + if user.id == current_user.id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="不能删除自己的账户", + ) + + # 记录操作 + log_operation( + db, + current_user.id, + "删除用户", + f"删除用户: {user.username}, ID: {user.id}, 角色: {user.role}", + request + ) + + # 删除用户 + db.delete(user) + db.commit() + + return {"success": True, "message": "用户已删除"} \ No newline at end of file diff --git a/server/routers/chat_router.py b/server/routers/chat_router.py index 9c5fb82d3..8fe2fb23f 100644 --- a/server/routers/chat_router.py +++ b/server/routers/chat_router.py @@ -13,11 +13,54 @@ from src.models import select_model from src.utils.logging_config import logger from src.agents.tools_factory import get_all_tools +from server.routers.auth_router import get_admin_user +from server.utils.auth_middleware import get_required_user +from server.models.user_model import User chat = APIRouter(prefix="/chat") +@chat.get("/default_agent") +async def get_default_agent(current_user: User = Depends(get_required_user)): + """获取默认智能体ID(需要登录)""" + try: + default_agent_id = config.default_agent_id + # 如果没有设置默认智能体,尝试获取第一个可用的智能体 + if not default_agent_id: + agents = [agent.get_info() for agent in agent_manager.agents.values()] + if agents: + default_agent_id = agents[0].get("name", "") + + return {"default_agent_id": default_agent_id} + except Exception as e: + logger.error(f"获取默认智能体出错: {e}") + raise HTTPException(status_code=500, detail=f"获取默认智能体出错: {str(e)}") + +@chat.post("/set_default_agent") +async def set_default_agent(agent_id: str = Body(..., embed=True), current_user = Depends(get_admin_user)): + """设置默认智能体ID (仅管理员)""" + try: + # 验证智能体是否存在 + agents = [agent.get_info() for agent in agent_manager.agents.values()] + agent_ids = [agent.get("name", "") for agent in agents] + + if agent_id not in agent_ids: + raise HTTPException(status_code=404, detail=f"智能体 {agent_id} 不存在") + + # 设置默认智能体ID + config.default_agent_id = agent_id + # 保存配置 + config.save() + + return {"success": True, "default_agent_id": agent_id} + except HTTPException as he: + raise he + except Exception as e: + logger.error(f"设置默认智能体出错: {e}") + raise HTTPException(status_code=500, detail=f"设置默认智能体出错: {str(e)}") + @chat.get("/") -async def chat_get(): +async def chat_get(current_user: User = Depends(get_required_user)): + """聊天服务健康检查(需要登录)""" return "Chat Get!" @chat.post("/") @@ -25,29 +68,9 @@ async def chat_post( query: str = Body(...), meta: dict = Body(None), history: list[dict] | None = Body(None), - thread_id: str | None = Body(None)): - """处理聊天请求的主要端点。 - Args: - query: 用户的输入查询文本 - meta: 包含请求元数据的字典,可以包含以下字段: - - use_web: 是否使用网络搜索 - - use_graph: 是否使用知识图谱 - - db_id: 数据库ID - - history_round: 历史对话轮数限制 - - system_prompt: 系统提示词(str,不含变量) - history: 对话历史记录列表 - thread_id: 对话线程ID - Returns: - StreamingResponse: 返回一个流式响应,包含以下状态: - - searching: 正在搜索知识库 - - generating: 正在生成回答 - - reasoning: 正在推理 - - loading: 正在加载回答 - - finished: 回答完成 - - error: 发生错误 - Raises: - HTTPException: 当检索器或模型发生错误时抛出 - """ + thread_id: str | None = Body(None), + current_user: User = Depends(get_required_user)): + """处理聊天请求的主要端点(需要登录)""" model = select_model() meta["server_model_name"] = model.model_name @@ -117,7 +140,8 @@ def generate_response(): return StreamingResponse(generate_response(), media_type='application/json') @chat.post("/call") -async def call(query: str = Body(...), meta: dict = Body(None)): +async def call(query: str = Body(...), meta: dict = Body(None), current_user: User = Depends(get_required_user)): + """调用模型进行简单问答(需要登录)""" meta = meta or {} model = select_model(model_provider=meta.get("model_provider"), model_name=meta.get("model_name")) async def predict_async(query): @@ -130,7 +154,8 @@ async def predict_async(query): return {"response": response.content} @chat.post("/call_lite") -async def call_lite(query: str = Body(...), meta: dict = Body(None)): +async def call_lite(query: str = Body(...), meta: dict = Body(None), current_user: User = Depends(get_required_user)): + """使用轻量版模型进行问答(需要登录)""" meta = meta or {} async def predict_async(query): loop = asyncio.get_event_loop() @@ -145,7 +170,8 @@ async def predict_async(query): return {"response": response.content} @chat.get("/agent") -async def get_agent(): +async def get_agent(current_user: User = Depends(get_required_user)): + """获取所有可用智能体(需要登录)""" agents = [agent.get_info() for agent in agent_manager.agents.values()] return {"agents": agents} @@ -154,12 +180,14 @@ def chat_agent(agent_name: str, query: str = Body(...), history: list = Body(...), config: dict = Body({}), - meta: dict = Body({})): + meta: dict = Body({}), + current_user: User = Depends(get_required_user)): + """使用特定智能体进行对话(需要登录)""" meta.update({ "query": query, "agent_name": agent_name, - "server_model_name": config.get("model", agent_name) , + "server_model_name": config.get("model", agent_name), "thread_id": config.get("thread_id"), }) @@ -223,19 +251,19 @@ def stream_messages(): return StreamingResponse(stream_messages(), media_type='application/json') @chat.get("/models") -async def get_chat_models(model_provider: str): - """获取指定模型提供商的模型列表""" +async def get_chat_models(model_provider: str, current_user: User = Depends(get_admin_user)): + """获取指定模型提供商的模型列表(需要登录)""" model = select_model(model_provider=model_provider) return {"models": model.get_models()} @chat.post("/models/update") -async def update_chat_models(model_provider: str, model_names: list[str]): - """更新指定模型提供商的模型列表""" +async def update_chat_models(model_provider: str, model_names: list[str], current_user = Depends(get_admin_user)): + """更新指定模型提供商的模型列表 (仅管理员)""" config.model_names[model_provider]["models"] = model_names config._save_models_to_file() return {"models": config.model_names[model_provider]["models"]} @chat.get("/tools") -async def get_tools(): - """获取所有工具""" +async def get_tools(current_user: User = Depends(get_admin_user)): + """获取所有可用工具(需要登录)""" return {"tools": list(get_all_tools().keys())} diff --git a/server/routers/data_router.py b/server/routers/data_router.py index 699ed9681..3b5534080 100644 --- a/server/routers/data_router.py +++ b/server/routers/data_router.py @@ -6,12 +6,14 @@ from src.utils import logger, hashstr from src import executor, retriever, config, knowledge_base, graph_base +from server.utils.auth_middleware import get_admin_user +from server.models.user_model import User data = APIRouter(prefix="/data") @data.get("/") -async def get_databases(): +async def get_databases(current_user: User = Depends(get_admin_user)): try: database = knowledge_base.get_databases() except Exception as e: @@ -23,7 +25,8 @@ async def get_databases(): async def create_database( database_name: str = Body(...), description: str = Body(...), - dimension: Optional[int] = Body(None) + dimension: Optional[int] = Body(None), + current_user: User = Depends(get_admin_user) ): logger.debug(f"Create database {database_name}") try: @@ -38,25 +41,25 @@ async def create_database( return database_info @data.delete("/") -async def delete_database(db_id): +async def delete_database(db_id, current_user: User = Depends(get_admin_user)): logger.debug(f"Delete database {db_id}") knowledge_base.delete_database(db_id) return {"message": "删除成功"} @data.post("/query-test") -async def query_test(query: str = Body(...), meta: dict = Body(...)): +async def query_test(query: str = Body(...), meta: dict = Body(...), current_user: User = Depends(get_admin_user)): logger.debug(f"Query test in {meta}: {query}") result = retriever.query_knowledgebase(query, history=None, refs={"meta": meta}) return result @data.post("/file-to-chunk") -async def file_to_chunk(files: List[str] = Body(...), params: dict = Body(...)): +async def file_to_chunk(files: List[str] = Body(...), params: dict = Body(...), current_user: User = Depends(get_admin_user)): logger.debug(f"File to chunk: {files}") result = knowledge_base.file_to_chunk(files, params=params) return result @data.post("/add-by-file") -async def create_document_by_file(db_id: str = Body(...), files: List[str] = Body(...)): +async def create_document_by_file(db_id: str = Body(...), files: List[str] = Body(...), current_user: User = Depends(get_admin_user)): logger.debug(f"Add document in {db_id} by file: {files}") try: # 使用线程池执行耗时操作 @@ -71,7 +74,7 @@ async def create_document_by_file(db_id: str = Body(...), files: List[str] = Bod return {"message": f"添加文件失败: {e}", "status": "failed"} @data.post("/add-by-chunks") -async def add_by_chunks(db_id: str = Body(...), file_chunks: dict = Body(...)): +async def add_by_chunks(db_id: str = Body(...), file_chunks: dict = Body(...), current_user: User = Depends(get_admin_user)): # logger.debug(f"Add chunks in {db_id}: {len(file_chunks)} chunks") try: loop = asyncio.get_event_loop() @@ -85,7 +88,7 @@ async def add_by_chunks(db_id: str = Body(...), file_chunks: dict = Body(...)): return {"message": f"添加分块失败: {e}", "status": "failed"} @data.get("/info") -async def get_database_info(db_id: str): +async def get_database_info(db_id: str, current_user: User = Depends(get_admin_user)): # logger.debug(f"Get database {db_id} info") database = knowledge_base.get_database_info(db_id) if database is None: @@ -93,13 +96,13 @@ async def get_database_info(db_id: str): return database @data.delete("/document") -async def delete_document(db_id: str = Body(...), file_id: str = Body(...)): +async def delete_document(db_id: str = Body(...), file_id: str = Body(...), current_user: User = Depends(get_admin_user)): logger.debug(f"DELETE document {file_id} info in {db_id}") knowledge_base.delete_file(db_id, file_id) return {"message": "删除成功"} @data.get("/document") -async def get_document_info(db_id: str, file_id: str): +async def get_document_info(db_id: str, file_id: str, current_user: User = Depends(get_admin_user)): logger.debug(f"GET document {file_id} info in {db_id}") try: @@ -113,7 +116,8 @@ async def get_document_info(db_id: str, file_id: str): @data.post("/upload") async def upload_file( file: UploadFile = File(...), - db_id: Optional[str] = Query(None) + db_id: Optional[str] = Query(None), + current_user: User = Depends(get_admin_user) ): if not file.filename: raise HTTPException(status_code=400, detail="No selected file") @@ -135,14 +139,14 @@ async def upload_file( return {"message": "File successfully uploaded", "file_path": file_path, "db_id": db_id} @data.get("/graph") -async def get_graph_info(): +async def get_graph_info(current_user: User = Depends(get_admin_user)): graph_info = graph_base.get_graph_info() if graph_info is None: raise HTTPException(status_code=400, detail="图数据库获取出错") return graph_info @data.post("/graph/index-nodes") -async def index_nodes(data: dict = Body(default={})): +async def index_nodes(data: dict = Body(default={}), current_user: User = Depends(get_admin_user)): if not graph_base.is_running(): raise HTTPException(status_code=400, detail="图数据库未启动") @@ -155,12 +159,12 @@ async def index_nodes(data: dict = Body(default={})): return {"status": "success", "message": f"已成功为{count}个节点添加嵌入向量", "indexed_count": count} @data.get("/graph/node") -async def get_graph_node(entity_name: str): +async def get_graph_node(entity_name: str, current_user: User = Depends(get_admin_user)): result = graph_base.query_node(entity_name=entity_name) return {"result": graph_base.format_query_result_to_graph(result), "message": "success"} @data.get("/graph/nodes") -async def get_graph_nodes(kgdb_name: str, num: int): +async def get_graph_nodes(kgdb_name: str, num: int, current_user: User = Depends(get_admin_user)): if not config.enable_knowledge_graph: raise HTTPException(status_code=400, detail="Knowledge graph is not enabled") @@ -169,7 +173,7 @@ async def get_graph_nodes(kgdb_name: str, num: int): return {"result": graph_base.format_general_results(result), "message": "success"} @data.post("/graph/add-by-jsonl") -async def add_graph_entity(file_path: str = Body(...), kgdb_name: Optional[str] = Body(None)): +async def add_graph_entity(file_path: str = Body(...), kgdb_name: Optional[str] = Body(None), current_user: User = Depends(get_admin_user)): if not config.enable_knowledge_graph: return {"message": "知识图谱未启用", "status": "failed"} diff --git a/server/utils/__init__.py b/server/utils/__init__.py new file mode 100644 index 000000000..4901b7189 --- /dev/null +++ b/server/utils/__init__.py @@ -0,0 +1 @@ +# utils包初始化文件 \ No newline at end of file diff --git a/server/utils/auth_middleware.py b/server/utils/auth_middleware.py new file mode 100644 index 000000000..692f872fb --- /dev/null +++ b/server/utils/auth_middleware.py @@ -0,0 +1,120 @@ +from typing import Optional, List +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer +from sqlalchemy.orm import Session +from jose import JWTError, jwt +import re + +from server.db_manager import db_manager +from server.models.user_model import User +from server.utils.auth_utils import AuthUtils + +# 定义OAuth2密码承载器,指定token URL +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/token", auto_error=False) + +# 公开路径列表,无需登录即可访问 +PUBLIC_PATHS = [ + r"^/auth/token$", # 登录 + r"^/auth/check-first-run$", # 检查是否首次运行 + r"^/auth/initialize$", # 初始化系统 + r"^/docs$", r"^/redoc$", r"^/openapi.json$", # API文档 + r"^/static/.*$", # 静态资源 + r"^/assets/.*$", # 前端资源文件 + r"^/$", # 根路径(登录页) + r"^/login$", # 登录页面 + r"^/home$", # 首页 + r"^/home/.*$", # 首页下的所有路径 + r"^/favicon\.ico$", # 网站图标 + r"^/_nuxt/.*$", # Nuxt.js生成的资源文件 + r"^/js/.*$", # JavaScript文件 + r"^/css/.*$", # CSS文件 + r"^/img/.*$" # 图片文件 +] + +# 获取数据库会话 +def get_db(): + db = db_manager.get_session() + try: + yield db + finally: + db.close() + +# 获取当前用户 +async def get_current_user(token: Optional[str] = Depends(oauth2_scheme), db: Session = Depends(get_db)): + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="无效的凭证", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # 允许无token访问公开路径 + if token is None: + return None + + try: + # 验证token + payload = AuthUtils.verify_access_token(token) + user_id = payload.get("sub") + if user_id is None: + raise credentials_exception + except JWTError: + raise credentials_exception + + # 查找用户 + user = db.query(User).filter(User.id == user_id).first() + if user is None: + raise credentials_exception + + return user + +# 获取已登录用户(抛出401如果未登录) +async def get_required_user(user: Optional[User] = Depends(get_current_user)): + if user is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="请登录后再访问", + headers={"WWW-Authenticate": "Bearer"}, + ) + return user + +# 获取管理员用户 +async def get_admin_user(current_user: User = Depends(get_required_user)): + if current_user.role not in ["admin", "superadmin"]: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="需要管理员权限", + ) + return current_user + +# 获取超级管理员用户 +async def get_superadmin_user(current_user: User = Depends(get_required_user)): + if current_user.role != "superadmin": + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="需要超级管理员权限", + ) + return current_user + +# 检查路径是否为公开路径 +def is_public_path(path: str) -> bool: + path = path.rstrip('/') # 去除尾部斜杠以便于匹配 + for pattern in PUBLIC_PATHS: + if re.match(pattern, path): + return True + return False + +# 路径是否需要管理员权限 +ADMIN_PATHS = [ + r"^/admin/.*$", # 管理员接口 + r"^/data/.*$", # 数据操作接口,所有数据操作都需要管理员权限 + r"^/chat/set_default_agent$", # 设置默认智能体 + r"^/chat/models/update$" # 更新模型列表 +] + +# 检查路径是否需要管理员权限 +def is_admin_path(path: str) -> bool: + path = path.rstrip('/') # 去除尾部斜杠以便于匹配 + for pattern in ADMIN_PATHS: + if re.match(pattern, path): + return True + return False \ No newline at end of file diff --git a/server/utils/auth_utils.py b/server/utils/auth_utils.py new file mode 100644 index 000000000..18ce903a3 --- /dev/null +++ b/server/utils/auth_utils.py @@ -0,0 +1,75 @@ +import hashlib +import os +import jwt +from datetime import datetime, timedelta +from typing import Optional, Dict, Any + +# JWT配置 +JWT_SECRET_KEY = os.environ.get("JWT_SECRET_KEY", "yuxi_know_secure_key") +JWT_ALGORITHM = "HS256" +JWT_EXPIRATION = 24 * 60 * 60 # 24小时过期 + +class AuthUtils: + """认证工具类""" + + @staticmethod + def hash_password(password: str) -> str: + """使用SHA-256哈希密码""" + # 生成盐 + salt = os.urandom(32).hex() + # 哈希密码 + hashed = hashlib.sha256((password + salt).encode()).hexdigest() + # 返回格式: "哈希值:盐" + return f"{hashed}:{salt}" + + @staticmethod + def verify_password(stored_password: str, provided_password: str) -> bool: + """验证密码""" + # 分离哈希值和盐 + if ":" not in stored_password: + return False + + hashed, salt = stored_password.split(":") + + # 使用相同的盐哈希提供的密码 + check_hash = hashlib.sha256((provided_password + salt).encode()).hexdigest() + + # 比较哈希值 + return hashed == check_hash + + @staticmethod + def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str: + """创建JWT访问令牌""" + to_encode = data.copy() + + # 设置过期时间 + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta(seconds=JWT_EXPIRATION) + + to_encode.update({"exp": expire}) + + # 编码JWT + encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM) + return encoded_jwt + + @staticmethod + def decode_token(token: str) -> Optional[Dict[str, Any]]: + """解码验证JWT令牌""" + try: + payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM]) + return payload + except jwt.PyJWTError: + return None + + @staticmethod + def verify_access_token(token: str) -> Dict[str, Any]: + """验证访问令牌,如果无效则抛出异常""" + try: + payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM]) + return payload + except jwt.ExpiredSignatureError: + raise ValueError("令牌已过期") + except jwt.InvalidTokenError: + raise ValueError("无效的令牌") \ No newline at end of file diff --git a/src/config/__init__.py b/src/config/__init__.py index b0df13178..99969c385 100644 --- a/src/config/__init__.py +++ b/src/config/__init__.py @@ -47,6 +47,8 @@ def __init__(self): self.add_item("enable_knowledge_base", default=False, des="是否开启知识库") self.add_item("enable_knowledge_graph", default=False, des="是否开启知识图谱") self.add_item("enable_web_search", default=False, des="是否开启网页搜索(注:现阶段会根据 TAVILY_API_KEY 自动开启,无法手动配置,将会在下个版本移除此配置项)") + # 默认智能体配置 + self.add_item("default_agent_id", default="", des="默认智能体ID") # 模型配置 ## 注意这里是模型名,而不是具体的模型路径,默认使用 HuggingFace 的路径 ## 如果需要自定义本地模型路径,则在 src/.env 中配置 MODEL_DIR diff --git a/web/src/apis/admin_api.js b/web/src/apis/admin_api.js new file mode 100644 index 000000000..5b4177cea --- /dev/null +++ b/web/src/apis/admin_api.js @@ -0,0 +1,314 @@ +import { apiGet, apiPost, apiPut, apiDelete } from './base' +import { useUserStore } from '@/stores/user' + +/** + * 管理员API模块 + * 只有管理员和超级管理员可以访问的API + * 权限要求: admin 或 superadmin + * + * 注意: 请确保在使用这些API之前检查用户是否具有管理员权限 + */ + +// 检查当前用户是否有管理员权限 +const checkAdminPermission = () => { + const userStore = useUserStore() + if (!userStore.isAdmin) { + throw new Error('需要管理员权限') + } + return true +} + +// 检查当前用户是否有超级管理员权限 +const checkSuperAdminPermission = () => { + const userStore = useUserStore() + if (!userStore.isSuperAdmin) { + throw new Error('需要超级管理员权限') + } + return true +} + +// 用户管理API +export const userManagementApi = { + /** + * 获取用户列表 + * @returns {Promise} - 用户列表 + */ + getUsers: async () => { + checkAdminPermission() + return apiGet('/api/auth/users', {}, true) + }, + + /** + * 创建新用户 + * @param {Object} userData - 用户数据 + * @returns {Promise} - 创建结果 + */ + createUser: async (userData) => { + checkAdminPermission() + return apiPost('/api/auth/users', userData, {}, true) + }, + + /** + * 更新用户 + * @param {number} userId - 用户ID + * @param {Object} userData - 用户数据 + * @returns {Promise} - 更新结果 + */ + updateUser: async (userId, userData) => { + checkAdminPermission() + return apiPut(`/api/auth/users/${userId}`, userData, {}, true) + }, + + /** + * 删除用户 + * @param {number} userId - 用户ID + * @returns {Promise} - 删除结果 + */ + deleteUser: async (userId) => { + checkAdminPermission() + return apiDelete(`/api/auth/users/${userId}`, {}, true) + }, +} + +// 令牌管理API +export const tokenApi = { + /** + * 获取令牌列表 + * @param {string} agentId - 智能体ID + * @returns {Promise} - 令牌列表 + */ + getTokens: async (agentId) => { + checkAdminPermission() + return apiGet(`/api/admin/tokens?agent_id=${agentId}`, {}, true) + }, + + /** + * 创建新令牌 + * @param {string} agentId - 智能体ID + * @param {string} name - 令牌名称 + * @returns {Promise} - 创建结果 + */ + createToken: async (agentId, name) => { + checkAdminPermission() + return apiPost('/api/admin/tokens', { agent_id: agentId, name }, {}, true) + }, + + /** + * 删除令牌 + * @param {string} tokenId - 令牌ID + * @returns {Promise} - 删除结果 + */ + deleteToken: async (tokenId) => { + checkAdminPermission() + return apiDelete(`/api/admin/tokens/${tokenId}`, {}, true) + }, +} + +// 知识库管理API +export const knowledgeBaseApi = { + /** + * 获取所有知识库 + * @returns {Promise} - 知识库列表 + */ + getDatabases: async () => { + checkAdminPermission() + return apiGet('/api/data/', {}, true) + }, + + /** + * 创建知识库 + * @param {Object} databaseData - 知识库数据 + * @returns {Promise} - 创建结果 + */ + createDatabase: async (databaseData) => { + checkAdminPermission() + return apiPost('/api/data/', databaseData, {}, true) + }, + + /** + * 获取知识库详情 + * @param {string} dbId - 知识库ID + * @returns {Promise} - 知识库详情 + */ + getDatabaseInfo: async (dbId) => { + checkAdminPermission() + return apiGet(`/api/data/info?db_id=${dbId}`, {}, true) + }, + + /** + * 删除知识库 + * @param {string} dbId - 知识库ID + * @returns {Promise} - 删除结果 + */ + deleteDatabase: async (dbId) => { + checkAdminPermission() + return apiDelete(`/api/data/?db_id=${dbId}`, {}, true) + }, + + /** + * 上传文件到知识库 + * @param {FormData} formData - 包含文件的FormData + * @param {string} dbId - 知识库ID + * @returns {Promise} - 上传结果 + */ + uploadFile: async (formData, dbId) => { + checkAdminPermission() + return fetch(`/api/data/upload?db_id=${dbId}`, { + method: 'POST', + headers: { + ...useUserStore().getAuthHeaders() + }, + body: formData + }).then(res => res.json()) + }, + + /** + * 删除文件 + * @param {string} dbId - 知识库ID + * @param {string} fileId - 文件ID + * @returns {Promise} - 删除结果 + */ + deleteFile: async (dbId, fileId) => { + checkAdminPermission() + return apiDelete('/api/data/document', { + body: JSON.stringify({ db_id: dbId, file_id: fileId }) + }, true) + }, + + /** + * 将文件分块 + * @param {Object} data - 分块参数 + * @returns {Promise} - 分块结果 + */ + fileToChunk: async (data) => { + checkAdminPermission() + return apiPost('/api/data/file-to-chunk', data, {}, true) + }, + + /** + * 将分块添加到数据库 + * @param {Object} data - 包含db_id和file_chunks的数据 + * @returns {Promise} - 添加结果 + */ + addByChunks: async (data) => { + checkAdminPermission() + return apiPost('/api/data/add-by-chunks', data, {}, true) + }, + + /** + * 查询测试 + * @param {Object} data - 查询参数 + * @returns {Promise} - 查询结果 + */ + queryTest: async (data) => { + checkAdminPermission() + return apiPost('/api/data/query-test', data, {}, true) + }, +} + +// 图数据库管理API +export const graphApi = { + /** + * 获取图数据库状态 + * @returns {Promise} - 图数据库状态 + */ + getGraphInfo: async () => { + checkAdminPermission() + return apiGet('/api/data/graph', {}, true) + }, + + /** + * 获取节点 + * @param {string} dbName - 图数据库名称 + * @param {number} num - 节点数量 + * @returns {Promise} - 节点数据 + */ + getNodes: async (dbName, num) => { + checkAdminPermission() + return apiGet(`/api/data/graph/nodes?kgdb_name=${dbName}&num=${num}`, {}, true) + }, + + /** + * 查询实体 + * @param {string} entityName - 实体名称 + * @returns {Promise} - 查询结果 + */ + queryNode: async (entityName) => { + checkAdminPermission() + return apiGet(`/api/data/graph/node?entity_name=${entityName}`, {}, true) + }, + + /** + * 添加JSONL文件到图数据库 + * @param {string} filePath - 文件路径 + * @returns {Promise} - 添加结果 + */ + addByJsonl: async (filePath) => { + checkAdminPermission() + return apiPost('/api/data/graph/add-by-jsonl', { file_path: filePath }, {}, true) + }, + + /** + * 为未索引节点添加索引 + * @param {string} dbName - 图数据库名称 + * @returns {Promise} - 索引结果 + */ + indexNodes: async (dbName) => { + checkAdminPermission() + return apiPost('/api/data/graph/index-nodes', { kgdb_name: dbName }, {}, true) + }, +} + +// 系统配置API +export const systemConfigApi = { + /** + * 设置默认智能体 + * @param {string} agentId - 智能体ID + * @returns {Promise} - 设置结果 + */ + setDefaultAgent: async (agentId) => { + checkAdminPermission() + return apiPost('/api/chat/set_default_agent', { agent_id: agentId }, {}, true) + }, + + /** + * 获取系统配置 + * @returns {Promise} - 系统配置 + */ + getSystemConfig: async () => { + checkAdminPermission() + return apiGet('/api/config', {}, true) + }, + + /** + * 更新系统配置 + * @param {Object} config - 配置项 + * @returns {Promise} - 更新结果 + */ + updateSystemConfig: async (config) => { + checkAdminPermission() + return apiPut('/api/config', config, {}, true) + }, + + /** + * 重启服务 + * @returns {Promise} - 重启结果 + */ + restartServer: async () => { + checkSuperAdminPermission() + return apiPost('/api/restart', {}, {}, true) + } +} + +// 日志API +export const logApi = { + /** + * 获取系统日志 + * @param {Object} params - 日志查询参数 + * @returns {Promise} - 日志数据 + */ + getLogs: async (params = {}) => { + checkAdminPermission() + return apiGet('/api/admin/logs', { params }, true) + }, +} \ No newline at end of file diff --git a/web/src/apis/auth_api.js b/web/src/apis/auth_api.js new file mode 100644 index 000000000..f971d5c7e --- /dev/null +++ b/web/src/apis/auth_api.js @@ -0,0 +1,106 @@ +import { apiGet, apiPost, apiDelete } from './base' + +/** + * 需要用户认证的API模块 + * 用户必须登录才能访问的API + * 权限要求: 任何已登录用户(普通用户、管理员、超级管理员) + */ + +// 聊天相关API +export const chatApi = { + /** + * 发送聊天消息 + * @param {Object} params - 聊天参数 + * @returns {Promise} - 聊天响应流 + */ + sendMessage: (params) => { + return fetch('/api/chat/', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify(params), + }) + }, + + /** + * 简单聊天调用(非流式) + * @param {string} query - 查询内容 + * @returns {Promise} - 聊天响应 + */ + simpleCall: (query) => apiPost('/api/chat/call', { query }, {}, true), + + /** + * 获取默认智能体 + * @returns {Promise} - 默认智能体信息 + */ + getDefaultAgent: () => apiGet('/api/chat/default_agent', {}, true), + + /** + * 获取智能体列表 + * @returns {Promise} - 智能体列表 + */ + getAgents: () => apiGet('/api/chat/agent', {}, true), + + /** + * 获取单个智能体详情 + * @param {string} agentId - 智能体ID + * @returns {Promise} - 智能体详情 + */ + getAgentDetail: (agentId) => apiGet(`/api/chat/agent/${agentId}`, {}, true), + + /** + * 获取可用工具列表 + * @returns {Promise} - 工具列表 + */ + getTools: () => apiGet('/api/chat/tools', {}, true), + + /** + * 获取对话历史 + * @returns {Promise} - 对话历史列表 + */ + getConversations: () => apiGet('/api/chat/conversations', {}, true), + + /** + * 获取特定对话 + * @param {string} conversationId - 对话ID + * @returns {Promise} - 对话详情 + */ + getConversation: (conversationId) => + apiGet(`/api/chat/conversations/${conversationId}`, {}, true), + + /** + * 删除对话 + * @param {string} conversationId - 对话ID + * @returns {Promise} - 删除结果 + */ + deleteConversation: (conversationId) => + apiDelete(`/api/chat/conversations/${conversationId}`, {}, true), + + /** + * 更新对话标题 + * @param {string} conversationId - 对话ID + * @param {string} title - 新标题 + * @returns {Promise} - 更新结果 + */ + updateConversationTitle: (conversationId, title) => + apiPost(`/api/chat/conversations/${conversationId}/title`, { title }, {}, true), +} + +// 用户设置API +export const userSettingsApi = { + /** + * 获取用户设置 + * @returns {Promise} - 用户设置 + */ + getSettings: () => apiGet('/api/user/settings', {}, true), + + /** + * 更新用户设置 + * @param {Object} settings - 新设置 + * @returns {Promise} - 更新结果 + */ + updateSettings: (settings) => apiPost('/api/user/settings', settings, {}, true), +} + +// 其他需要用户认证的API可以继续添加到这里 \ No newline at end of file diff --git a/web/src/apis/base.js b/web/src/apis/base.js new file mode 100644 index 000000000..78e4bdf82 --- /dev/null +++ b/web/src/apis/base.js @@ -0,0 +1,142 @@ +import { useUserStore } from '@/stores/user' +import { message } from 'ant-design-vue' + +/** + * 基础API请求封装 + * 提供统一的请求方法,自动处理认证头和错误 + */ + +/** + * 发送API请求的基础函数 + * @param {string} url - API端点 + * @param {Object} options - 请求选项 + * @param {boolean} requiresAuth - 是否需要认证头 + * @returns {Promise} - 请求结果 + */ +export async function apiRequest(url, options = {}, requiresAuth = false) { + try { + // 默认请求配置 + const requestOptions = { + ...options, + headers: { + 'Content-Type': 'application/json', + ...options.headers, + }, + } + + // 如果需要认证,添加认证头 + if (requiresAuth) { + const userStore = useUserStore() + if (!userStore.isLoggedIn) { + throw new Error('用户未登录') + } + + Object.assign(requestOptions.headers, userStore.getAuthHeaders()) + } + + // 发送请求 + const response = await fetch(url, requestOptions) + + // 处理API返回的错误 + if (!response.ok) { + // 尝试解析错误信息 + let errorMessage = `请求失败: ${response.status}` + try { + const errorData = await response.json() + errorMessage = errorData.detail || errorData.message || errorMessage + } catch (e) { + // 如果无法解析JSON,使用默认错误信息 + } + + // 特殊处理401和403错误 + if (response.status === 401) { + // 如果是认证失败,可能需要重新登录 + const userStore = useUserStore() + if (userStore.isLoggedIn) { + // 如果用户认为自己已登录,但收到401,则可能是令牌过期 + message.error('登录已过期,请重新登录') + userStore.logout() + window.location.href = '/login' + } + throw new Error('未授权,请先登录') + } else if (response.status === 403) { + throw new Error('没有权限执行此操作') + } + + throw new Error(errorMessage) + } + + // 检查Content-Type以确定如何处理响应 + const contentType = response.headers.get('Content-Type') + if (contentType && contentType.includes('application/json')) { + return await response.json() + } + + return await response.text() + } catch (error) { + console.error('API请求错误:', error) + throw error + } +} + +/** + * 发送GET请求 + * @param {string} url - API端点 + * @param {Object} options - 请求选项 + * @param {boolean} requiresAuth - 是否需要认证 + * @returns {Promise} - 请求结果 + */ +export function apiGet(url, options = {}, requiresAuth = false) { + return apiRequest(url, { method: 'GET', ...options }, requiresAuth) +} + +/** + * 发送POST请求 + * @param {string} url - API端点 + * @param {Object} data - 请求体数据 + * @param {Object} options - 其他请求选项 + * @param {boolean} requiresAuth - 是否需要认证 + * @returns {Promise} - 请求结果 + */ +export function apiPost(url, data = {}, options = {}, requiresAuth = false) { + return apiRequest( + url, + { + method: 'POST', + body: JSON.stringify(data), + ...options + }, + requiresAuth + ) +} + +/** + * 发送PUT请求 + * @param {string} url - API端点 + * @param {Object} data - 请求体数据 + * @param {Object} options - 其他请求选项 + * @param {boolean} requiresAuth - 是否需要认证 + * @returns {Promise} - 请求结果 + */ +export function apiPut(url, data = {}, options = {}, requiresAuth = false) { + return apiRequest( + url, + { + method: 'PUT', + body: JSON.stringify(data), + ...options + }, + requiresAuth + ) +} + +/** + * 发送DELETE请求 + * @param {string} url - API端点 + * @param {Object} options - 请求选项 + * @param {boolean} requiresAuth - 是否需要认证 + * @returns {Promise} - 请求结果 + */ +export function apiDelete(url, options = {}, requiresAuth = false) { + return apiRequest(url, { method: 'DELETE', ...options }, requiresAuth) +} \ No newline at end of file diff --git a/web/src/apis/index.js b/web/src/apis/index.js new file mode 100644 index 000000000..ffa124ccb --- /dev/null +++ b/web/src/apis/index.js @@ -0,0 +1,33 @@ +/** + * API模块索引文件 + * 导出所有API模块,方便统一引入 + */ + +// 导出公共API模块 +export * from './public_api' + +// 导出需要用户认证的API模块 +export * from './auth_api' + +// 导出需要管理员权限的API模块 +export * from './admin_api' + +// 导出基础工具函数 +export { apiRequest, apiGet, apiPost, apiPut, apiDelete } from './base' + +/** + * 权限说明: + * + * 1. public_api.js: 不需要认证就可以访问的API + * - 登录、初始化管理员、获取公共配置等 + * + * 2. auth_api.js: 需要用户认证才能访问的API + * - 权限要求: 任何已登录用户(普通用户、管理员、超级管理员) + * - 聊天功能、个人设置等 + * + * 3. admin_api.js: 需要管理员权限才能访问的API + * - 权限要求: admin 或 superadmin + * - 用户管理、知识库管理、系统配置等 + * + * 注意:本模块已处理权限验证和请求头,使用时无需再手动添加认证头 + */ \ No newline at end of file diff --git a/web/src/apis/public_api.js b/web/src/apis/public_api.js new file mode 100644 index 000000000..88d759dea --- /dev/null +++ b/web/src/apis/public_api.js @@ -0,0 +1,59 @@ +import { apiGet, apiPost } from './base' + +/** + * 公共API模块 + * 包含所有不需要认证的公共接口 + */ + +// 登录相关API +export const authApi = { + /** + * 用户登录 + * @param {Object} credentials - 登录凭证 + * @returns {Promise} - 登录结果 + */ + login: (credentials) => { + const formData = new FormData() + formData.append('username', credentials.username) + formData.append('password', credentials.password) + + return apiRequest('/api/auth/token', { + method: 'POST', + body: formData + }, false) + }, + + /** + * 检查是否是首次运行 + * @returns {Promise} - 是否首次运行 + */ + checkFirstRun: () => apiGet('/api/auth/check-first-run'), + + /** + * 初始化管理员账户 + * @param {Object} adminData - 管理员账户数据 + * @returns {Promise} - 初始化结果 + */ + initializeAdmin: (adminData) => apiPost('/api/auth/initialize', adminData), +} + +// 配置相关API +export const configApi = { + /** + * 获取系统配置 + * @returns {Promise} - 系统配置 + */ + getConfig: () => apiGet('/api/config'), +} + +// 健康检查API +export const healthApi = { + /** + * 系统健康检查 + * @returns {Promise} - 健康检查结果 + */ + check: () => apiGet('/api/health'), +} + +// 从base.js导入apiRequest以支持FormData +import { apiRequest } from './base' \ No newline at end of file diff --git a/web/src/components/AgentChatComponent.vue b/web/src/components/AgentChatComponent.vue index 5272d74be..6ca3c1b09 100644 --- a/web/src/components/AgentChatComponent.vue +++ b/web/src/components/AgentChatComponent.vue @@ -110,6 +110,8 @@ import { import { message } from 'ant-design-vue'; import MessageInputComponent from '@/components/MessageInputComponent.vue' import MessageComponent from '@/components/MessageComponent.vue' +import { useUserStore } from '@/stores/user' +import { chatApi } from '@/apis/auth_api' // 新增props属性,允许父组件传入agentId const props = defineProps({ @@ -127,6 +129,9 @@ const props = defineProps({ } }); +// 初始化userStore +const userStore = useUserStore(); + // ==================== 状态管理 ==================== // UI状态 @@ -356,7 +361,10 @@ const sendMessageWithText = async (text) => { // 发送请求 const response = await fetch(`/api/chat/agent/${currentAgent.value.name}`, { method: 'POST', - headers: { 'Content-Type': 'application/json' }, + headers: { + 'Content-Type': 'application/json', + ...userStore.getAuthHeaders() + }, body: JSON.stringify(requestData) }); @@ -752,18 +760,13 @@ const appendToolMessageToExistingAssistant = async (data) => { // 获取智能体列表 const fetchAgents = async () => { try { - const response = await fetch('/api/chat/agent'); - if (response.ok) { - const data = await response.json(); - // 将数组转换为对象 - agents.value = data.agents.reduce((acc, agent) => { - acc[agent.name] = agent; - return acc; - }, {}); - console.log("agents", agents.value); - } else { - console.error('获取智能体失败'); - } + const data = await chatApi.getAgents(); + // 将数组转换为对象 + agents.value = data.agents.reduce((acc, agent) => { + acc[agent.name] = agent; + return acc; + }, {}); + console.log("agents", agents.value); } catch (error) { console.error('获取智能体错误:', error); } diff --git a/web/src/components/ChatComponent.vue b/web/src/components/ChatComponent.vue index 6ef374008..b314cfc4a 100644 --- a/web/src/components/ChatComponent.vue +++ b/web/src/components/ChatComponent.vue @@ -184,10 +184,12 @@ import { } from '@ant-design/icons-vue' import { onClickOutside } from '@vueuse/core' import { useConfigStore } from '@/stores/config' +import { useUserStore } from '@/stores/user' import { message } from 'ant-design-vue' import MessageInputComponent from '@/components/MessageInputComponent.vue' import MessageComponent from '@/components/MessageComponent.vue' import RefsSidebar from '@/components/RefsSidebar.vue' +import { chatApi } from '@/apis/auth_api' const props = defineProps({ conv: Object, @@ -196,6 +198,7 @@ const props = defineProps({ const emit = defineEmits(['rename-title', 'newconv']); const configStore = useConfigStore() +const userStore = useUserStore() const { conv, state } = toRefs(props) const chatContainer = ref(null) @@ -475,50 +478,56 @@ const groupRefs = (id) => { scrollToBottom() } -const simpleCall = (msg) => { - return new Promise((resolve, reject) => { - fetch('/api/chat/call', { - method: 'POST', - body: JSON.stringify({ query: msg, }), - headers: { 'Content-Type': 'application/json' } - }) - .then((response) => response.json()) - .then((data) => resolve(data)) - .catch((error) => reject(error)) - }) -} - const loadDatabases = () => { - fetch('/api/data/', { method: "GET", }) + fetch('/api/data/', { + method: "GET", + headers: userStore.getAuthHeaders() + }) .then(response => response.json()) .then(data => { console.log(data) opts.databases = data.databases }) + .catch(error => { + console.error('加载数据库列表失败:', error) + }) } -// 新函数用于处理 fetch 请求 +const simpleCall = (msg) => { + return new Promise((resolve, reject) => { + chatApi.simpleCall(msg) + .then(data => resolve(data)) + .catch(error => reject(error)) + }) +} + +// 替换fetchChatResponse函数 const fetchChatResponse = (user_input, cur_res_id) => { const controller = new AbortController(); const signal = controller.signal; const params = { query: user_input, - history: getHistory().slice(0, -1), // 去掉最后一条刚添加的用户消息, + history: getHistory().slice(0, -1), // 去掉最后一条刚添加的用户消息 meta: meta, cur_res_id: cur_res_id, } console.log(params) + // 使用fetch带上认证头和信号控制 fetch('/api/chat/', { method: 'POST', body: JSON.stringify(params), headers: { - 'Content-Type': 'application/json' + 'Content-Type': 'application/json', + ...userStore.getAuthHeaders() }, signal // 添加 signal 用于中断请求 }) .then((response) => { + if (!response.ok) { + throw new Error(`请求失败: ${response.status} ${response.statusText}`) + } if (!response.body) throw new Error("ReadableStream not supported."); const reader = response.body.getReader(); const decoder = new TextDecoder("utf-8"); @@ -587,6 +596,7 @@ const fetchChatResponse = (user_input, cur_res_id) => { updateMessage({ id: cur_res_id, status: "error", + message: error.message || '请求失败', }); } isStreaming.value = false; diff --git a/web/src/components/TokenManagerComponent.vue b/web/src/components/TokenManagerComponent.vue index 4b074f859..c29b7d4ea 100644 --- a/web/src/components/TokenManagerComponent.vue +++ b/web/src/components/TokenManagerComponent.vue @@ -61,6 +61,7 @@ import { ref, onMounted, watch } from 'vue'; import { message, Empty } from 'ant-design-vue'; import { PlusOutlined, DeleteOutlined, CopyOutlined } from '@ant-design/icons-vue'; +import { tokenApi } from '@/apis/admin_api'; const props = defineProps({ agentId: { @@ -81,16 +82,11 @@ const newToken = ref({ const fetchTokens = async () => { loading.value = true; try { - const response = await fetch(`/api/admin/tokens?agent_id=${props.agentId}`); - if (response.ok) { - const data = await response.json(); - tokens.value = data; - } else { - message.error('获取令牌列表失败'); - } + const data = await tokenApi.getTokens(props.agentId); + tokens.value = data; } catch (error) { console.error('获取令牌列表出错:', error); - message.error('获取令牌列表出错'); + message.error(error.message || '获取令牌列表出错'); } finally { loading.value = false; } @@ -104,48 +100,26 @@ const createToken = async () => { } try { - const response = await fetch('/api/admin/tokens', { - method: 'POST', - headers: { - 'Content-Type': 'application/json' - }, - body: JSON.stringify({ - agent_id: props.agentId, - name: newToken.value.name - }) - }); - - if (response.ok) { - const data = await response.json(); - tokens.value.push(data); - message.success('令牌创建成功'); - addTokenModalVisible.value = false; - newToken.value.name = ''; - } else { - message.error('创建令牌失败'); - } + const data = await tokenApi.createToken(props.agentId, newToken.value.name); + tokens.value.push(data); + message.success('令牌创建成功'); + addTokenModalVisible.value = false; + newToken.value.name = ''; } catch (error) { console.error('创建令牌出错:', error); - message.error('创建令牌出错'); + message.error(error.message || '创建令牌出错'); } }; // 删除令牌 const deleteToken = async (tokenId) => { try { - const response = await fetch(`/api/admin/tokens/${tokenId}`, { - method: 'DELETE' - }); - - if (response.ok) { - tokens.value = tokens.value.filter(token => token.id !== tokenId); - message.success('令牌已删除'); - } else { - message.error('删除令牌失败'); - } + await tokenApi.deleteToken(tokenId); + tokens.value = tokens.value.filter(token => token.id !== tokenId); + message.success('令牌已删除'); } catch (error) { console.error('删除令牌出错:', error); - message.error('删除令牌出错'); + message.error(error.message || '删除令牌出错'); } }; diff --git a/web/src/components/UserInfoComponent.vue b/web/src/components/UserInfoComponent.vue new file mode 100644 index 000000000..9afe0d941 --- /dev/null +++ b/web/src/components/UserInfoComponent.vue @@ -0,0 +1,139 @@ + + + + + \ No newline at end of file diff --git a/web/src/layouts/AppLayout.vue b/web/src/layouts/AppLayout.vue index e29115ce1..e17bf6b34 100644 --- a/web/src/layouts/AppLayout.vue +++ b/web/src/layouts/AppLayout.vue @@ -29,6 +29,8 @@ import { themeConfig } from '@/assets/theme' import { useConfigStore } from '@/stores/config' import { useDatabaseStore } from '@/stores/database' import DebugComponent from '@/components/DebugComponent.vue' +import UserInfoComponent from '@/components/UserInfoComponent.vue' +import { configApi } from '@/apis/public_api' const configStore = useConfigStore() const databaseStore = useDatabaseStore() @@ -57,11 +59,12 @@ const getRemoteDatabase = () => { const fetchGithubStars = async () => { try { isLoadingStars.value = true + // 公共API,可以直接使用fetch const response = await fetch('https://api.github.com/repos/xerrors/Yuxi-Know') const data = await response.json() githubStars.value = data.stargazers_count } catch (error) { - console.error('Error fetching GitHub stars:', error) + console.error('获取GitHub stars失败:', error) } finally { isLoadingStars.value = false } @@ -102,8 +105,8 @@ const mainList = [{ activeIcon: BookFilled, // hidden: !configStore.config.enable_knowledge_base, }, { - name: '工具', - path: '/tools', + name: '智能体', + path: '/agent', icon: ToolOutlined, activeIcon: ToolFilled, } @@ -163,6 +166,10 @@ const mainList = [{
+ + + +