diff --git a/README.md b/README.md index cf3984469..53784b014 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,12 @@ ## 📝 项目概述 +DEV 更新待办: + +- 智能体的消息加载有问题 +- 智能体的管理员的配置无法更新到用户层面 + + 语析是一个强大的问答平台,结合了大模型 RAG 知识库与知识图谱技术,基于 Llamaindex + VueJS + FastAPI + Neo4j 构建。 **核心特点:** diff --git a/pyproject.toml b/pyproject.toml index 64cc709b1..7651724ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,9 +19,11 @@ dependencies = [ "openai>=1.76.0", "opencv-python-headless>=4.11.0.86", "paddleocr>=2.10.0", + "pyjwt>=2.8.0", "pymilvus>=2.5.8", "pymupdf>=1.25.5", "python-dotenv>=1.1.0", + "python-jose[cryptography]>=3.4.0", "python-multipart>=0.0.20", "pyyaml>=6.0.2", "qianfan>=0.4.12.3", diff --git a/requirements.txt b/requirements.txt index 790c4f8ab..e0a7a1374 100644 --- a/requirements.txt +++ b/requirements.txt @@ -31,4 +31,5 @@ langchain langsmith langgraph langchain-openai -langchain-community \ No newline at end of file +langchain-community +PyJWT>=2.10.1 \ No newline at end of file diff --git a/server/db_manager.py b/server/db_manager.py index 0a05bcb91..af1350fdb 100644 --- a/server/db_manager.py +++ b/server/db_manager.py @@ -1,11 +1,10 @@ import os -import sqlite3 import pathlib from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker -from sqlalchemy.ext.declarative import declarative_base -from server.models.token_model import Base, AgentToken +from server.models import Base +from server.models.user_model import User class DBManager: """数据库管理器""" @@ -32,6 +31,15 @@ def create_tables(self): """创建数据库表""" Base.metadata.create_all(self.engine) + def check_first_run(self): + """检查是否首次运行""" + session = self.get_session() + try: + # 检查是否有任何用户存在 + return session.query(User).count() == 0 + finally: + session.close() + def get_session(self): """获取数据库会话""" return self.Session() diff --git a/server/main.py b/server/main.py index 93277ccf1..bfcc41db0 100644 --- a/server/main.py +++ b/server/main.py @@ -1,13 +1,17 @@ import uvicorn -from fastapi import FastAPI +from fastapi import FastAPI, Request, HTTPException, status, Depends from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from starlette.middleware.base import BaseHTTPMiddleware + from server.routers import router +from server.utils.auth_middleware import is_public_path from src.utils.logging_config import logger app = FastAPI() -app.include_router(router) +app.include_router(router, prefix="/api") # CORS 设置 app.add_middleware( @@ -18,6 +22,40 @@ allow_headers=["*"], ) +# 鉴权中间件 +class AuthMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + # 获取请求路径 + path = request.url.path + + # 检查是否为公开路径,公开路径无需身份验证 + if is_public_path(path): + return await call_next(request) + + if not path.startswith("/api"): + # 非API路径,可能是前端路由或静态资源 + return await call_next(request) + + # # 提取Authorization头 + # auth_header = request.headers.get("Authorization") + # if not auth_header or not auth_header.startswith("Bearer "): + # return JSONResponse( + # status_code=status.HTTP_401_UNAUTHORIZED, + # content={"detail": f"请先登录。Path: {path}"}, + # headers={"WWW-Authenticate": "Bearer"} + # ) + + # # 获取token + # token = auth_header.split("Bearer ")[1] + + # # 添加token到请求状态,后续路由可以直接使用 + # request.state.token = token + + # 继续处理请求 + return await call_next(request) + +# 添加鉴权中间件 +app.add_middleware(AuthMiddleware) if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=5050, threads=10, workers=10, reload=True) diff --git a/server/models/__init__.py b/server/models/__init__.py new file mode 100644 index 000000000..7c2377aec --- /dev/null +++ b/server/models/__init__.py @@ -0,0 +1,3 @@ +from sqlalchemy.ext.declarative import declarative_base + +Base = declarative_base() \ No newline at end of file diff --git a/server/models/token_model.py b/server/models/token_model.py deleted file mode 100644 index 74ee1d596..000000000 --- a/server/models/token_model.py +++ /dev/null @@ -1,24 +0,0 @@ -from sqlalchemy import Column, Integer, String, DateTime, ForeignKey -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.sql import func - -Base = declarative_base() - -class AgentToken(Base): - """智能体访问令牌模型""" - __tablename__ = 'agent_tokens' - - id = Column(Integer, primary_key=True, autoincrement=True) - agent_id = Column(String, nullable=False, index=True) # 智能体ID - name = Column(String, nullable=False) # 令牌名称 - token = Column(String, nullable=False, unique=True) # 令牌值 - created_at = Column(DateTime, default=func.now()) # 创建时间 - - def to_dict(self): - return { - "id": self.id, - "agent_id": self.agent_id, - "name": self.name, - "token": self.token, - "created_at": self.created_at.isoformat() if self.created_at else None - } \ No newline at end of file diff --git a/server/models/user_model.py b/server/models/user_model.py new file mode 100644 index 000000000..527d1f817 --- /dev/null +++ b/server/models/user_model.py @@ -0,0 +1,55 @@ +from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, Text +from sqlalchemy.sql import func +from sqlalchemy.orm import relationship + +from server.models import Base + +class User(Base): + """用户模型""" + __tablename__ = 'users' + + id = Column(Integer, primary_key=True, autoincrement=True) + username = Column(String, nullable=False, unique=True, index=True) + password_hash = Column(String, nullable=False) + role = Column(String, nullable=False, default='user') # 角色: superadmin, admin, user + created_at = Column(DateTime, default=func.now()) + last_login = Column(DateTime, nullable=True) + + # 关联操作日志 + operation_logs = relationship("OperationLog", back_populates="user") + + def to_dict(self, include_password=False): + result = { + "id": self.id, + "username": self.username, + "role": self.role, + "created_at": self.created_at.isoformat() if self.created_at else None, + "last_login": self.last_login.isoformat() if self.last_login else None + } + if include_password: + result["password_hash"] = self.password_hash + return result + +class OperationLog(Base): + """操作日志模型""" + __tablename__ = 'operation_logs' + + id = Column(Integer, primary_key=True, autoincrement=True) + user_id = Column(Integer, ForeignKey('users.id'), nullable=False) + operation = Column(String, nullable=False) + details = Column(Text, nullable=True) + ip_address = Column(String, nullable=True) + timestamp = Column(DateTime, default=func.now()) + + # 关联用户 + user = relationship("User", back_populates="operation_logs") + + def to_dict(self): + return { + "id": self.id, + "user_id": self.user_id, + "operation": self.operation, + "details": self.details, + "ip_address": self.ip_address, + "timestamp": self.timestamp.isoformat() if self.timestamp else None + } \ No newline at end of file diff --git a/server/routers/__init__.py b/server/routers/__init__.py index 6878d9347..017bb2b52 100644 --- a/server/routers/__init__.py +++ b/server/routers/__init__.py @@ -2,12 +2,10 @@ from server.routers.chat_router import chat from server.routers.data_router import data from server.routers.base_router import base -from server.routers.tool_router import tool -from server.routers.admin_router import admin +from server.routers.auth_router import auth router = APIRouter() router.include_router(base) router.include_router(chat) router.include_router(data) -router.include_router(tool) -router.include_router(admin) +router.include_router(auth) diff --git a/server/routers/admin_router.py b/server/routers/admin_router.py deleted file mode 100644 index 9471b6bc3..000000000 --- a/server/routers/admin_router.py +++ /dev/null @@ -1,103 +0,0 @@ -import secrets -import string -from fastapi import APIRouter, Depends, HTTPException, Query -from pydantic import BaseModel -from typing import List, Optional -from sqlalchemy.orm import Session - -from server.db_manager import db_manager -from server.models.token_model import AgentToken - -admin = APIRouter(prefix="/admin", tags=["admin"]) - -# 依赖项:获取数据库会话 -def get_db(): - db = db_manager.get_session() - try: - yield db - finally: - db.close() - -# 请求和响应模型 -class TokenCreate(BaseModel): - agent_id: str - name: str - -class TokenVerify(BaseModel): - agent_id: str - token: str - -class TokenResponse(BaseModel): - id: int - agent_id: str - name: str - token: str - created_at: str - -# 生成随机token -def generate_token(length=32): - alphabet = string.ascii_letters + string.digits - return ''.join(secrets.choice(alphabet) for _ in range(length)) - -@admin.get("/tokens", response_model=List[TokenResponse]) -async def get_agent_tokens( - agent_id: Optional[str] = Query(None), - db: Session = Depends(get_db) -): - """获取智能体的token列表""" - query = db.query(AgentToken) - if agent_id: - query = query.filter(AgentToken.agent_id == agent_id) - tokens = query.all() - return [token.to_dict() for token in tokens] - -@admin.post("/tokens", response_model=TokenResponse) -async def create_token( - token_data: TokenCreate, - db: Session = Depends(get_db) -): - """创建新的token""" - # 生成随机token - token_value = generate_token() - - # 创建token记录 - new_token = AgentToken( - agent_id=token_data.agent_id, - name=token_data.name, - token=token_value - ) - - # 保存到数据库 - db.add(new_token) - db.commit() - db.refresh(new_token) - - return new_token.to_dict() - -@admin.delete("/tokens/{token_id}", response_model=dict) -async def delete_token(token_id: int, db: Session = Depends(get_db)): - """删除token""" - token = db.query(AgentToken).filter(AgentToken.id == token_id).first() - if not token: - raise HTTPException(status_code=404, detail="Token not found") - - db.delete(token) - db.commit() - - return {"success": True, "message": "Token deleted"} - -@admin.post("/verify_token") -async def verify_agent_token( - token_data: TokenVerify, - db: Session = Depends(get_db) -): - """验证智能体访问令牌""" - token = db.query(AgentToken).filter( - AgentToken.agent_id == token_data.agent_id, - AgentToken.token == token_data.token - ).first() - - if not token: - raise HTTPException(status_code=401, detail="Invalid token") - - return {"success": True, "message": "Token verified"} \ No newline at end of file diff --git a/server/routers/auth_router.py b/server/routers/auth_router.py new file mode 100644 index 000000000..08654bf81 --- /dev/null +++ b/server/routers/auth_router.py @@ -0,0 +1,347 @@ +from fastapi import APIRouter, Depends, HTTPException, Request, status +from fastapi.security import OAuth2PasswordRequestForm +from pydantic import BaseModel +from typing import List, Optional +from sqlalchemy.orm import Session +from datetime import datetime, timedelta + +from server.db_manager import db_manager +from server.models.user_model import User, OperationLog +from server.utils.auth_utils import AuthUtils +from server.utils.auth_middleware import get_db, get_current_user, get_admin_user, get_superadmin_user, oauth2_scheme + +# 创建路由器 +auth = APIRouter(prefix="/auth", tags=["auth"]) + +# 请求和响应模型 +class Token(BaseModel): + access_token: str + token_type: str + user_id: int + username: str + role: str + +class UserCreate(BaseModel): + username: str + password: str + role: str = "user" + +class UserUpdate(BaseModel): + username: Optional[str] = None + password: Optional[str] = None + role: Optional[str] = None + +class UserResponse(BaseModel): + id: int + username: str + role: str + created_at: str + last_login: Optional[str] = None + +class InitializeAdmin(BaseModel): + username: str + password: str + +# 记录操作日志 +def log_operation(db: Session, user_id: int, operation: str, details: str = None, request: Request = None): + ip_address = None + if request: + ip_address = request.client.host if request.client else None + + log = OperationLog( + user_id=user_id, + operation=operation, + details=details, + ip_address=ip_address + ) + db.add(log) + db.commit() + +# 路由:登录获取令牌 +@auth.post("/token", response_model=Token) +async def login_for_access_token( + form_data: OAuth2PasswordRequestForm = Depends(), + db: Session = Depends(get_db) +): + # 查找用户 + user = db.query(User).filter(User.username == form_data.username).first() + + # 验证用户存在且密码正确 + if not user or not AuthUtils.verify_password(user.password_hash, form_data.password): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="用户名或密码错误", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # 更新最后登录时间 + user.last_login = datetime.now() + db.commit() + + # 生成访问令牌 + token_data = {"sub": str(user.id)} + access_token = AuthUtils.create_access_token(token_data) + + # 记录登录操作 + log_operation(db, user.id, "登录") + + return { + "access_token": access_token, + "token_type": "bearer", + "user_id": user.id, + "username": user.username, + "role": user.role + } + +# 路由:校验是否需要初始化管理员 +@auth.get("/check-first-run") +async def check_first_run(): + is_first_run = db_manager.check_first_run() + return {"first_run": is_first_run} + +# 路由:初始化管理员账户 +@auth.post("/initialize", response_model=Token) +async def initialize_admin( + admin_data: InitializeAdmin, + db: Session = Depends(get_db) +): + # 检查是否是首次运行 + if not db_manager.check_first_run(): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="系统已经初始化,无法再次创建初始管理员", + ) + + # 创建管理员账户 + hashed_password = AuthUtils.hash_password(admin_data.password) + + new_admin = User( + username=admin_data.username, + password_hash=hashed_password, + role="superadmin", + last_login=datetime.now() + ) + + db.add(new_admin) + db.commit() + db.refresh(new_admin) + + # 生成访问令牌 + token_data = {"sub": str(new_admin.id)} + access_token = AuthUtils.create_access_token(token_data) + + # 记录操作 + log_operation(db, new_admin.id, "系统初始化", "创建超级管理员账户") + + return { + "access_token": access_token, + "token_type": "bearer", + "user_id": new_admin.id, + "username": new_admin.username, + "role": new_admin.role + } + +# 路由:获取当前用户信息 +@auth.get("/me", response_model=UserResponse) +async def read_users_me(current_user: User = Depends(get_current_user)): + return current_user.to_dict() + +# 路由:创建新用户(管理员权限) +@auth.post("/users", response_model=UserResponse) +async def create_user( + user_data: UserCreate, + request: Request, + current_user: User = Depends(get_admin_user), + db: Session = Depends(get_db) +): + # 检查用户名是否已存在 + existing_user = db.query(User).filter(User.username == user_data.username).first() + if existing_user: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="用户名已存在", + ) + + # 创建新用户 + hashed_password = AuthUtils.hash_password(user_data.password) + + # 检查角色权限 + # 超级管理员可以创建任何类型的用户 + if user_data.role == "superadmin" and current_user.role != "superadmin": + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="只有超级管理员才能创建超级管理员账户", + ) + + # 管理员只能创建普通用户 + if current_user.role == "admin" and user_data.role != "user": + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="管理员只能创建普通用户账户", + ) + + new_user = User( + username=user_data.username, + password_hash=hashed_password, + role=user_data.role + ) + + db.add(new_user) + db.commit() + db.refresh(new_user) + + # 记录操作 + log_operation( + db, + current_user.id, + "创建用户", + f"创建用户: {user_data.username}, 角色: {user_data.role}", + request + ) + + return new_user.to_dict() + +# 路由:获取所有用户(管理员权限) +@auth.get("/users", response_model=List[UserResponse]) +async def read_users( + skip: int = 0, + limit: int = 100, + current_user: User = Depends(get_admin_user), + db: Session = Depends(get_db) +): + users = db.query(User).offset(skip).limit(limit).all() + return [user.to_dict() for user in users] + +# 路由:获取特定用户信息(管理员权限) +@auth.get("/users/{user_id}", response_model=UserResponse) +async def read_user( + user_id: int, + current_user: User = Depends(get_admin_user), + db: Session = Depends(get_db) +): + user = db.query(User).filter(User.id == user_id).first() + if user is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="用户不存在", + ) + return user.to_dict() + +# 路由:更新用户信息(管理员权限) +@auth.put("/users/{user_id}", response_model=UserResponse) +async def update_user( + user_id: int, + user_data: UserUpdate, + request: Request, + current_user: User = Depends(get_admin_user), + db: Session = Depends(get_db) +): + user = db.query(User).filter(User.id == user_id).first() + if user is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="用户不存在", + ) + + # 检查权限 + if user.role == "superadmin" and current_user.role != "superadmin": + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="只有超级管理员才能修改超级管理员账户", + ) + + # 超级管理员账户不能被降级(只能由其他超级管理员修改) + if user.role == "superadmin" and user_data.role and user_data.role != "superadmin" and current_user.id != user.id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="不能降级超级管理员账户", + ) + + # 更新信息 + update_details = [] + + if user_data.username is not None: + # 检查用户名是否已被其他用户使用 + existing_user = db.query(User).filter(User.username == user_data.username, User.id != user_id).first() + if existing_user: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="用户名已存在", + ) + user.username = user_data.username + update_details.append(f"用户名: {user_data.username}") + + if user_data.password is not None: + user.password_hash = AuthUtils.hash_password(user_data.password) + update_details.append("密码已更新") + + if user_data.role is not None: + user.role = user_data.role + update_details.append(f"角色: {user_data.role}") + + db.commit() + + # 记录操作 + log_operation( + db, + current_user.id, + "更新用户", + f"更新用户ID {user_id}: {', '.join(update_details)}", + request + ) + + return user.to_dict() + +# 路由:删除用户(管理员权限) +@auth.delete("/users/{user_id}", response_model=dict) +async def delete_user( + user_id: int, + request: Request, + current_user: User = Depends(get_admin_user), + db: Session = Depends(get_db) +): + user = db.query(User).filter(User.id == user_id).first() + if user is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="用户不存在", + ) + + # 检查权限 + if user.role == "superadmin": + # 只有超级管理员可以删除超级管理员 + if current_user.role != "superadmin": + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="只有超级管理员才能删除超级管理员账户", + ) + + # 检查是否是最后一个超级管理员 + superadmin_count = db.query(User).filter(User.role == "superadmin").count() + if superadmin_count <= 1: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="不能删除最后一个超级管理员账户", + ) + + # 不能删除自己的账户 + if user.id == current_user.id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="不能删除自己的账户", + ) + + # 记录操作 + log_operation( + db, + current_user.id, + "删除用户", + f"删除用户: {user.username}, ID: {user.id}, 角色: {user.role}", + request + ) + + # 删除用户 + db.delete(user) + db.commit() + + return {"success": True, "message": "用户已删除"} \ No newline at end of file diff --git a/server/routers/base_router.py b/server/routers/base_router.py index 8ee6f3283..ea70877b9 100644 --- a/server/routers/base_router.py +++ b/server/routers/base_router.py @@ -1,4 +1,4 @@ -from fastapi import Request, Body +from fastapi import Request, Body, Depends from fastapi import APIRouter from fastapi import Request, Body @@ -6,6 +6,8 @@ from src import config, retriever, knowledge_base, graph_base from src.utils import logger +from server.utils.auth_middleware import get_admin_user, get_superadmin_user +from server.models.user_model import User @base.get("/") @@ -13,27 +15,37 @@ async def route_index(): return {"message": "You Got It!"} @base.get("/config") -def get_config(): - return config.get_safe_config() +def get_config(current_user: User = Depends(get_admin_user)): + return config.dump_config() @base.post("/config") -async def update_config(key = Body(...), value = Body(...)): - if key == "custom_models": - value = config.compare_custom_models(value) - +async def update_config( + key = Body(...), + value = Body(...), + current_user: User = Depends(get_admin_user) +) -> dict: config[key] = value config.save() - return config.get_safe_config() + return config.dump_config() + +@base.post("/config/update") +async def update_config_item( + items: dict = Body(...), + current_user: User = Depends(get_admin_user) +) -> dict: + config.update(items) + config.save() + return config.dump_config() @base.post("/restart") -async def restart(): +async def restart(current_user: User = Depends(get_superadmin_user)): knowledge_base.restart() graph_base.start() retriever.restart() return {"message": "Restarted!"} @base.get("/log") -def get_log(): +def get_log(current_user: User = Depends(get_admin_user)): from src.utils.logging_config import LOG_FILE from collections import deque diff --git a/server/routers/chat_router.py b/server/routers/chat_router.py index 9c5fb82d3..8fe2fb23f 100644 --- a/server/routers/chat_router.py +++ b/server/routers/chat_router.py @@ -13,11 +13,54 @@ from src.models import select_model from src.utils.logging_config import logger from src.agents.tools_factory import get_all_tools +from server.routers.auth_router import get_admin_user +from server.utils.auth_middleware import get_required_user +from server.models.user_model import User chat = APIRouter(prefix="/chat") +@chat.get("/default_agent") +async def get_default_agent(current_user: User = Depends(get_required_user)): + """获取默认智能体ID(需要登录)""" + try: + default_agent_id = config.default_agent_id + # 如果没有设置默认智能体,尝试获取第一个可用的智能体 + if not default_agent_id: + agents = [agent.get_info() for agent in agent_manager.agents.values()] + if agents: + default_agent_id = agents[0].get("name", "") + + return {"default_agent_id": default_agent_id} + except Exception as e: + logger.error(f"获取默认智能体出错: {e}") + raise HTTPException(status_code=500, detail=f"获取默认智能体出错: {str(e)}") + +@chat.post("/set_default_agent") +async def set_default_agent(agent_id: str = Body(..., embed=True), current_user = Depends(get_admin_user)): + """设置默认智能体ID (仅管理员)""" + try: + # 验证智能体是否存在 + agents = [agent.get_info() for agent in agent_manager.agents.values()] + agent_ids = [agent.get("name", "") for agent in agents] + + if agent_id not in agent_ids: + raise HTTPException(status_code=404, detail=f"智能体 {agent_id} 不存在") + + # 设置默认智能体ID + config.default_agent_id = agent_id + # 保存配置 + config.save() + + return {"success": True, "default_agent_id": agent_id} + except HTTPException as he: + raise he + except Exception as e: + logger.error(f"设置默认智能体出错: {e}") + raise HTTPException(status_code=500, detail=f"设置默认智能体出错: {str(e)}") + @chat.get("/") -async def chat_get(): +async def chat_get(current_user: User = Depends(get_required_user)): + """聊天服务健康检查(需要登录)""" return "Chat Get!" @chat.post("/") @@ -25,29 +68,9 @@ async def chat_post( query: str = Body(...), meta: dict = Body(None), history: list[dict] | None = Body(None), - thread_id: str | None = Body(None)): - """处理聊天请求的主要端点。 - Args: - query: 用户的输入查询文本 - meta: 包含请求元数据的字典,可以包含以下字段: - - use_web: 是否使用网络搜索 - - use_graph: 是否使用知识图谱 - - db_id: 数据库ID - - history_round: 历史对话轮数限制 - - system_prompt: 系统提示词(str,不含变量) - history: 对话历史记录列表 - thread_id: 对话线程ID - Returns: - StreamingResponse: 返回一个流式响应,包含以下状态: - - searching: 正在搜索知识库 - - generating: 正在生成回答 - - reasoning: 正在推理 - - loading: 正在加载回答 - - finished: 回答完成 - - error: 发生错误 - Raises: - HTTPException: 当检索器或模型发生错误时抛出 - """ + thread_id: str | None = Body(None), + current_user: User = Depends(get_required_user)): + """处理聊天请求的主要端点(需要登录)""" model = select_model() meta["server_model_name"] = model.model_name @@ -117,7 +140,8 @@ def generate_response(): return StreamingResponse(generate_response(), media_type='application/json') @chat.post("/call") -async def call(query: str = Body(...), meta: dict = Body(None)): +async def call(query: str = Body(...), meta: dict = Body(None), current_user: User = Depends(get_required_user)): + """调用模型进行简单问答(需要登录)""" meta = meta or {} model = select_model(model_provider=meta.get("model_provider"), model_name=meta.get("model_name")) async def predict_async(query): @@ -130,7 +154,8 @@ async def predict_async(query): return {"response": response.content} @chat.post("/call_lite") -async def call_lite(query: str = Body(...), meta: dict = Body(None)): +async def call_lite(query: str = Body(...), meta: dict = Body(None), current_user: User = Depends(get_required_user)): + """使用轻量版模型进行问答(需要登录)""" meta = meta or {} async def predict_async(query): loop = asyncio.get_event_loop() @@ -145,7 +170,8 @@ async def predict_async(query): return {"response": response.content} @chat.get("/agent") -async def get_agent(): +async def get_agent(current_user: User = Depends(get_required_user)): + """获取所有可用智能体(需要登录)""" agents = [agent.get_info() for agent in agent_manager.agents.values()] return {"agents": agents} @@ -154,12 +180,14 @@ def chat_agent(agent_name: str, query: str = Body(...), history: list = Body(...), config: dict = Body({}), - meta: dict = Body({})): + meta: dict = Body({}), + current_user: User = Depends(get_required_user)): + """使用特定智能体进行对话(需要登录)""" meta.update({ "query": query, "agent_name": agent_name, - "server_model_name": config.get("model", agent_name) , + "server_model_name": config.get("model", agent_name), "thread_id": config.get("thread_id"), }) @@ -223,19 +251,19 @@ def stream_messages(): return StreamingResponse(stream_messages(), media_type='application/json') @chat.get("/models") -async def get_chat_models(model_provider: str): - """获取指定模型提供商的模型列表""" +async def get_chat_models(model_provider: str, current_user: User = Depends(get_admin_user)): + """获取指定模型提供商的模型列表(需要登录)""" model = select_model(model_provider=model_provider) return {"models": model.get_models()} @chat.post("/models/update") -async def update_chat_models(model_provider: str, model_names: list[str]): - """更新指定模型提供商的模型列表""" +async def update_chat_models(model_provider: str, model_names: list[str], current_user = Depends(get_admin_user)): + """更新指定模型提供商的模型列表 (仅管理员)""" config.model_names[model_provider]["models"] = model_names config._save_models_to_file() return {"models": config.model_names[model_provider]["models"]} @chat.get("/tools") -async def get_tools(): - """获取所有工具""" +async def get_tools(current_user: User = Depends(get_admin_user)): + """获取所有可用工具(需要登录)""" return {"tools": list(get_all_tools().keys())} diff --git a/server/routers/data_router.py b/server/routers/data_router.py index 699ed9681..3b5534080 100644 --- a/server/routers/data_router.py +++ b/server/routers/data_router.py @@ -6,12 +6,14 @@ from src.utils import logger, hashstr from src import executor, retriever, config, knowledge_base, graph_base +from server.utils.auth_middleware import get_admin_user +from server.models.user_model import User data = APIRouter(prefix="/data") @data.get("/") -async def get_databases(): +async def get_databases(current_user: User = Depends(get_admin_user)): try: database = knowledge_base.get_databases() except Exception as e: @@ -23,7 +25,8 @@ async def get_databases(): async def create_database( database_name: str = Body(...), description: str = Body(...), - dimension: Optional[int] = Body(None) + dimension: Optional[int] = Body(None), + current_user: User = Depends(get_admin_user) ): logger.debug(f"Create database {database_name}") try: @@ -38,25 +41,25 @@ async def create_database( return database_info @data.delete("/") -async def delete_database(db_id): +async def delete_database(db_id, current_user: User = Depends(get_admin_user)): logger.debug(f"Delete database {db_id}") knowledge_base.delete_database(db_id) return {"message": "删除成功"} @data.post("/query-test") -async def query_test(query: str = Body(...), meta: dict = Body(...)): +async def query_test(query: str = Body(...), meta: dict = Body(...), current_user: User = Depends(get_admin_user)): logger.debug(f"Query test in {meta}: {query}") result = retriever.query_knowledgebase(query, history=None, refs={"meta": meta}) return result @data.post("/file-to-chunk") -async def file_to_chunk(files: List[str] = Body(...), params: dict = Body(...)): +async def file_to_chunk(files: List[str] = Body(...), params: dict = Body(...), current_user: User = Depends(get_admin_user)): logger.debug(f"File to chunk: {files}") result = knowledge_base.file_to_chunk(files, params=params) return result @data.post("/add-by-file") -async def create_document_by_file(db_id: str = Body(...), files: List[str] = Body(...)): +async def create_document_by_file(db_id: str = Body(...), files: List[str] = Body(...), current_user: User = Depends(get_admin_user)): logger.debug(f"Add document in {db_id} by file: {files}") try: # 使用线程池执行耗时操作 @@ -71,7 +74,7 @@ async def create_document_by_file(db_id: str = Body(...), files: List[str] = Bod return {"message": f"添加文件失败: {e}", "status": "failed"} @data.post("/add-by-chunks") -async def add_by_chunks(db_id: str = Body(...), file_chunks: dict = Body(...)): +async def add_by_chunks(db_id: str = Body(...), file_chunks: dict = Body(...), current_user: User = Depends(get_admin_user)): # logger.debug(f"Add chunks in {db_id}: {len(file_chunks)} chunks") try: loop = asyncio.get_event_loop() @@ -85,7 +88,7 @@ async def add_by_chunks(db_id: str = Body(...), file_chunks: dict = Body(...)): return {"message": f"添加分块失败: {e}", "status": "failed"} @data.get("/info") -async def get_database_info(db_id: str): +async def get_database_info(db_id: str, current_user: User = Depends(get_admin_user)): # logger.debug(f"Get database {db_id} info") database = knowledge_base.get_database_info(db_id) if database is None: @@ -93,13 +96,13 @@ async def get_database_info(db_id: str): return database @data.delete("/document") -async def delete_document(db_id: str = Body(...), file_id: str = Body(...)): +async def delete_document(db_id: str = Body(...), file_id: str = Body(...), current_user: User = Depends(get_admin_user)): logger.debug(f"DELETE document {file_id} info in {db_id}") knowledge_base.delete_file(db_id, file_id) return {"message": "删除成功"} @data.get("/document") -async def get_document_info(db_id: str, file_id: str): +async def get_document_info(db_id: str, file_id: str, current_user: User = Depends(get_admin_user)): logger.debug(f"GET document {file_id} info in {db_id}") try: @@ -113,7 +116,8 @@ async def get_document_info(db_id: str, file_id: str): @data.post("/upload") async def upload_file( file: UploadFile = File(...), - db_id: Optional[str] = Query(None) + db_id: Optional[str] = Query(None), + current_user: User = Depends(get_admin_user) ): if not file.filename: raise HTTPException(status_code=400, detail="No selected file") @@ -135,14 +139,14 @@ async def upload_file( return {"message": "File successfully uploaded", "file_path": file_path, "db_id": db_id} @data.get("/graph") -async def get_graph_info(): +async def get_graph_info(current_user: User = Depends(get_admin_user)): graph_info = graph_base.get_graph_info() if graph_info is None: raise HTTPException(status_code=400, detail="图数据库获取出错") return graph_info @data.post("/graph/index-nodes") -async def index_nodes(data: dict = Body(default={})): +async def index_nodes(data: dict = Body(default={}), current_user: User = Depends(get_admin_user)): if not graph_base.is_running(): raise HTTPException(status_code=400, detail="图数据库未启动") @@ -155,12 +159,12 @@ async def index_nodes(data: dict = Body(default={})): return {"status": "success", "message": f"已成功为{count}个节点添加嵌入向量", "indexed_count": count} @data.get("/graph/node") -async def get_graph_node(entity_name: str): +async def get_graph_node(entity_name: str, current_user: User = Depends(get_admin_user)): result = graph_base.query_node(entity_name=entity_name) return {"result": graph_base.format_query_result_to_graph(result), "message": "success"} @data.get("/graph/nodes") -async def get_graph_nodes(kgdb_name: str, num: int): +async def get_graph_nodes(kgdb_name: str, num: int, current_user: User = Depends(get_admin_user)): if not config.enable_knowledge_graph: raise HTTPException(status_code=400, detail="Knowledge graph is not enabled") @@ -169,7 +173,7 @@ async def get_graph_nodes(kgdb_name: str, num: int): return {"result": graph_base.format_general_results(result), "message": "success"} @data.post("/graph/add-by-jsonl") -async def add_graph_entity(file_path: str = Body(...), kgdb_name: Optional[str] = Body(None)): +async def add_graph_entity(file_path: str = Body(...), kgdb_name: Optional[str] = Body(None), current_user: User = Depends(get_admin_user)): if not config.enable_knowledge_graph: return {"message": "知识图谱未启用", "status": "failed"} diff --git a/server/routers/tool_router.py b/server/routers/tool_router.py deleted file mode 100644 index a587db7a0..000000000 --- a/server/routers/tool_router.py +++ /dev/null @@ -1,70 +0,0 @@ -import os -from fastapi import APIRouter, Body -from pydantic import BaseModel -from typing import List, Dict, Any, Optional - -from src.agents import agent_manager - -tool = APIRouter(prefix="/tool") - - -class Tool(BaseModel): - name: str - title: str - description: str - url: str - method: Optional[str] = "POST" - params: Optional[Dict[str, Any]] = None - metadata: Optional[Dict[str, Any]] = None - -@tool.get("/", response_model=List[Tool]) -async def route_index(): - tools = [ - Tool( - name="text-chunking", - title="文本分块", - description="将文本分块以更好地理解。可以输入文本或者上传文件。", - url="/tools/text-chunking", - method="POST", - ), - Tool( - name="pdf2txt", - title="PDF转文本", - description="将PDF文件转换为文本文件。", - url="/tools/pdf2txt", - method="POST", - ), - Tool( - name="agent", - title="智能体(Dev)", - description="智能体演练平台,现在还处于开发预览状态,欢迎提 Issue,但先不要用于正式场景。", - url="/tools/agent", - ) - ] - - for agent in agent_manager.agents.values(): - tools.append( - Tool( - name=agent.name, - title=agent.name, - description=agent.description, - url=f"/agent/{agent.name}", - method="POST", - metadata=agent.config_schema.to_dict(), - ) - ) - - return tools - -@tool.post("/text-chunking") -async def text_chunking(text: str = Body(...), params: Dict[str, Any] = Body(...)): - from src.core.indexing import chunk - nodes = chunk(text, params=params) - return {"nodes": [node.to_dict() for node in nodes]} - -@tool.post("/pdf2txt") -async def handle_pdf2txt(file: str = Body(...)): - from src.plugins import ocr - text = ocr.process_pdf(file) - return {"text": text} - diff --git a/server/utils/__init__.py b/server/utils/__init__.py new file mode 100644 index 000000000..4901b7189 --- /dev/null +++ b/server/utils/__init__.py @@ -0,0 +1 @@ +# utils包初始化文件 \ No newline at end of file diff --git a/server/utils/auth_middleware.py b/server/utils/auth_middleware.py new file mode 100644 index 000000000..6bc69d049 --- /dev/null +++ b/server/utils/auth_middleware.py @@ -0,0 +1,102 @@ +from typing import Optional, List, Callable +from fastapi import Depends, HTTPException, status, APIRouter +from fastapi.security import OAuth2PasswordBearer +from sqlalchemy.orm import Session +from jose import JWTError, jwt +import re + +from server.db_manager import db_manager +from server.models.user_model import User +from server.utils.auth_utils import AuthUtils + +# 定义OAuth2密码承载器,指定token URL +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/token", auto_error=False) + +# 公开路径列表,无需登录即可访问 +PUBLIC_PATHS = [ + r"^/api/auth/token$", # 登录 + r"^/api/auth/check-first-run$", # 检查是否首次运行 + r"^/api/auth/initialize$", # 初始化系统 + r"^/api$", # Health Check + r"^/api/login$", # 登录页面 +] + +# 获取数据库会话 +def get_db(): + db = db_manager.get_session() + try: + yield db + finally: + db.close() + +# 获取当前用户 +async def get_current_user(token: Optional[str] = Depends(oauth2_scheme), db: Session = Depends(get_db)): + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="无效的凭证", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # 允许无token访问公开路径 + if token is None: + return None + + try: + # 验证token + payload = AuthUtils.verify_access_token(token) + user_id = payload.get("sub") + if user_id is None: + raise credentials_exception + except JWTError: + raise credentials_exception + except ValueError as e: + # 捕获AuthUtils.verify_access_token可能抛出的ValueError + # 例如令牌过期或无效 + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=str(e), # 将错误信息直接传递给客户端 + headers={"WWW-Authenticate": "Bearer"}, + ) + + # 查找用户 + user = db.query(User).filter(User.id == user_id).first() + if user is None: + raise credentials_exception + + return user + +# 获取已登录用户(抛出401如果未登录) +async def get_required_user(user: Optional[User] = Depends(get_current_user)): + if user is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="请登录后再访问", + headers={"WWW-Authenticate": "Bearer"}, + ) + return user + +# 获取管理员用户 +async def get_admin_user(current_user: User = Depends(get_required_user)): + if current_user.role not in ["admin", "superadmin"]: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="需要管理员权限", + ) + return current_user + +# 获取超级管理员用户 +async def get_superadmin_user(current_user: User = Depends(get_required_user)): + if current_user.role != "superadmin": + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="需要超级管理员权限", + ) + return current_user + +# 检查路径是否为公开路径 +def is_public_path(path: str) -> bool: + path = path.rstrip('/') # 去除尾部斜杠以便于匹配 + for pattern in PUBLIC_PATHS: + if re.match(pattern, path): + return True + return False diff --git a/server/utils/auth_utils.py b/server/utils/auth_utils.py new file mode 100644 index 000000000..18ce903a3 --- /dev/null +++ b/server/utils/auth_utils.py @@ -0,0 +1,75 @@ +import hashlib +import os +import jwt +from datetime import datetime, timedelta +from typing import Optional, Dict, Any + +# JWT配置 +JWT_SECRET_KEY = os.environ.get("JWT_SECRET_KEY", "yuxi_know_secure_key") +JWT_ALGORITHM = "HS256" +JWT_EXPIRATION = 24 * 60 * 60 # 24小时过期 + +class AuthUtils: + """认证工具类""" + + @staticmethod + def hash_password(password: str) -> str: + """使用SHA-256哈希密码""" + # 生成盐 + salt = os.urandom(32).hex() + # 哈希密码 + hashed = hashlib.sha256((password + salt).encode()).hexdigest() + # 返回格式: "哈希值:盐" + return f"{hashed}:{salt}" + + @staticmethod + def verify_password(stored_password: str, provided_password: str) -> bool: + """验证密码""" + # 分离哈希值和盐 + if ":" not in stored_password: + return False + + hashed, salt = stored_password.split(":") + + # 使用相同的盐哈希提供的密码 + check_hash = hashlib.sha256((provided_password + salt).encode()).hexdigest() + + # 比较哈希值 + return hashed == check_hash + + @staticmethod + def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str: + """创建JWT访问令牌""" + to_encode = data.copy() + + # 设置过期时间 + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta(seconds=JWT_EXPIRATION) + + to_encode.update({"exp": expire}) + + # 编码JWT + encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM) + return encoded_jwt + + @staticmethod + def decode_token(token: str) -> Optional[Dict[str, Any]]: + """解码验证JWT令牌""" + try: + payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM]) + return payload + except jwt.PyJWTError: + return None + + @staticmethod + def verify_access_token(token: str) -> Dict[str, Any]: + """验证访问令牌,如果无效则抛出异常""" + try: + payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM]) + return payload + except jwt.ExpiredSignatureError: + raise ValueError("令牌已过期") + except jwt.InvalidTokenError: + raise ValueError("无效的令牌") \ No newline at end of file diff --git a/src/config/__init__.py b/src/config/__init__.py index b0df13178..6bc14fdda 100644 --- a/src/config/__init__.py +++ b/src/config/__init__.py @@ -29,6 +29,10 @@ def __setitem__(self, key, value): def __dict__(self): return {k: v for k, v in self.items()} + def update(self, other): + for key, value in other.items(): + self[key] = value + class Config(SimpleConfig): @@ -47,6 +51,8 @@ def __init__(self): self.add_item("enable_knowledge_base", default=False, des="是否开启知识库") self.add_item("enable_knowledge_graph", default=False, des="是否开启知识图谱") self.add_item("enable_web_search", default=False, des="是否开启网页搜索(注:现阶段会根据 TAVILY_API_KEY 自动开启,无法手动配置,将会在下个版本移除此配置项)") + # 默认智能体配置 + self.add_item("default_agent_id", default="", des="默认智能体ID") # 模型配置 ## 注意这里是模型名,而不是具体的模型路径,默认使用 HuggingFace 的路径 ## 如果需要自定义本地模型路径,则在 src/.env 中配置 MODEL_DIR @@ -212,18 +218,8 @@ def save(self): logger.info(f"Config file {self.filename} saved") - def get_safe_config(self): - """ - 获取安全的配置,即过滤掉 api_key - """ - - config = json.loads(str(self)) - - # 过滤掉 api_key - for model in config.get("custom_models", []): - model["api_key"] = DEFAULT_MOCK_API if model.get("api_key") else "" - - return config + def dump_config(self): + return json.loads(str(self)) def compare_custom_models(self, value): """ diff --git a/src/core/history.py b/src/core/history.py index 277a8d385..6e5c597e3 100644 --- a/src/core/history.py +++ b/src/core/history.py @@ -3,11 +3,14 @@ class HistoryManager(): def __init__(self, history=None, system_prompt=None): - self.messages = history or [] + self.messages = [] system_prompt = system_prompt or get_system_prompt() self.add_system(system_prompt) + if history: + self.messages.extend(history) + def add(self, role, content): self.messages.append({"role": role, "content": content}) return self.messages diff --git a/web/src/apis/admin_api.js b/web/src/apis/admin_api.js new file mode 100644 index 000000000..102656343 --- /dev/null +++ b/web/src/apis/admin_api.js @@ -0,0 +1,293 @@ +import { apiGet, apiPost, apiPut, apiDelete } from './base' +import { useUserStore } from '@/stores/user' + +/** + * 管理员API模块 + * 只有管理员和超级管理员可以访问的API + * 权限要求: admin 或 superadmin + * + * 注意: 请确保在使用这些API之前检查用户是否具有管理员权限 + */ + +// 检查当前用户是否有管理员权限 +const checkAdminPermission = () => { + const userStore = useUserStore() + if (!userStore.isAdmin) { + throw new Error('需要管理员权限') + } + return true +} + +// 检查当前用户是否有超级管理员权限 +const checkSuperAdminPermission = () => { + const userStore = useUserStore() + if (!userStore.isSuperAdmin) { + throw new Error('需要超级管理员权限') + } + return true +} + +// 用户管理API +export const userManagementApi = { + /** + * 获取用户列表 + * @returns {Promise} - 用户列表 + */ + getUsers: async () => { + checkAdminPermission() + return apiGet('/api/auth/users', {}, true) + }, + + /** + * 创建新用户 + * @param {Object} userData - 用户数据 + * @returns {Promise} - 创建结果 + */ + createUser: async (userData) => { + checkAdminPermission() + return apiPost('/api/auth/users', userData, {}, true) + }, + + /** + * 更新用户 + * @param {number} userId - 用户ID + * @param {Object} userData - 用户数据 + * @returns {Promise} - 更新结果 + */ + updateUser: async (userId, userData) => { + checkAdminPermission() + return apiPut(`/api/auth/users/${userId}`, userData, {}, true) + }, + + /** + * 删除用户 + * @param {number} userId - 用户ID + * @returns {Promise} - 删除结果 + */ + deleteUser: async (userId) => { + checkAdminPermission() + return apiDelete(`/api/auth/users/${userId}`, {}, true) + }, +} + +// 知识库管理API +export const knowledgeBaseApi = { + /** + * 获取所有知识库 + * @returns {Promise} - 知识库列表 + */ + getDatabases: async () => { + checkAdminPermission() + return apiGet('/api/data/', {}, true) + }, + + /** + * 创建知识库 + * @param {Object} databaseData - 知识库数据 + * @returns {Promise} - 创建结果 + */ + createDatabase: async (databaseData) => { + checkAdminPermission() + return apiPost('/api/data/', databaseData, {}, true) + }, + + /** + * 获取知识库详情 + * @param {string} dbId - 知识库ID + * @returns {Promise} - 知识库详情 + */ + getDatabaseInfo: async (dbId) => { + checkAdminPermission() + return apiGet(`/api/data/info?db_id=${dbId}`, {}, true) + }, + + /** + * 删除知识库 + * @param {string} dbId - 知识库ID + * @returns {Promise} - 删除结果 + */ + deleteDatabase: async (dbId) => { + checkAdminPermission() + return apiDelete(`/api/data/?db_id=${dbId}`, {}, true) + }, + + /** + * 上传文件到知识库 + * @param {FormData} formData - 包含文件的FormData + * @param {string} dbId - 知识库ID + * @returns {Promise} - 上传结果 + */ + uploadFile: async (formData, dbId) => { + checkAdminPermission() + return fetch(`/api/data/upload?db_id=${dbId}`, { + method: 'POST', + headers: { + ...useUserStore().getAuthHeaders() + }, + body: formData + }).then(res => res.json()) + }, + + /** + * 删除文件 + * @param {string} dbId - 知识库ID + * @param {string} fileId - 文件ID + * @returns {Promise} - 删除结果 + */ + deleteFile: async (dbId, fileId) => { + checkAdminPermission() + return apiDelete('/api/data/document', { + body: JSON.stringify({ db_id: dbId, file_id: fileId }) + }, true) + }, + + /** + * 将文件分块 + * @param {Object} data - 分块参数 + * @returns {Promise} - 分块结果 + */ + fileToChunk: async (data) => { + checkAdminPermission() + return apiPost('/api/data/file-to-chunk', data, {}, true) + }, + + /** + * 将分块添加到数据库 + * @param {Object} data - 包含db_id和file_chunks的数据 + * @returns {Promise} - 添加结果 + */ + addByChunks: async (data) => { + checkAdminPermission() + return apiPost('/api/data/add-by-chunks', data, {}, true) + }, + + /** + * 查询测试 + * @param {Object} data - 查询参数 + * @returns {Promise} - 查询结果 + */ + queryTest: async (data) => { + checkAdminPermission() + return apiPost('/api/data/query-test', data, {}, true) + }, + + /** + * 获取文档详情 + * @param {string} dbId - 知识库ID + * @param {string} fileId - 文件ID + * @returns {Promise} - 文档详情 + */ + getDocumentDetail: async (dbId, fileId) => { + checkAdminPermission() + return apiGet(`/api/data/document?db_id=${dbId}&file_id=${fileId}`, {}, true) + } +} + +// 图数据库管理API +export const graphApi = { + /** + * 获取图数据库状态 + * @returns {Promise} - 图数据库状态 + */ + getGraphInfo: async () => { + checkAdminPermission() + return apiGet('/api/data/graph', {}, true) + }, + + /** + * 获取节点 + * @param {string} dbName - 图数据库名称 + * @param {number} num - 节点数量 + * @returns {Promise} - 节点数据 + */ + getNodes: async (dbName, num) => { + checkAdminPermission() + return apiGet(`/api/data/graph/nodes?kgdb_name=${dbName}&num=${num}`, {}, true) + }, + + /** + * 查询实体 + * @param {string} entityName - 实体名称 + * @returns {Promise} - 查询结果 + */ + queryNode: async (entityName) => { + checkAdminPermission() + return apiGet(`/api/data/graph/node?entity_name=${entityName}`, {}, true) + }, + + /** + * 添加JSONL文件到图数据库 + * @param {string} filePath - 文件路径 + * @returns {Promise} - 添加结果 + */ + addByJsonl: async (filePath) => { + checkAdminPermission() + return apiPost('/api/data/graph/add-by-jsonl', { file_path: filePath }, {}, true) + }, + + /** + * 为未索引节点添加索引 + * @param {string} dbName - 图数据库名称 + * @returns {Promise} - 索引结果 + */ + indexNodes: async (dbName) => { + checkAdminPermission() + return apiPost('/api/data/graph/index-nodes', { kgdb_name: dbName }, {}, true) + }, +} + +// 系统配置API +export const systemConfigApi = { + /** + * 设置默认智能体 + * @param {string} agentId - 智能体ID + * @returns {Promise} - 设置结果 + */ + setDefaultAgent: async (agentId) => { + checkAdminPermission() + return apiPost('/api/chat/set_default_agent', { agent_id: agentId }, {}, true) + }, + + /** + * 获取系统配置 + * @returns {Promise} - 系统配置 + */ + getSystemConfig: async () => { + checkAdminPermission() + return apiGet('/api/config', {}, true) + }, + + + /** + * 更新某个配置 + * @param {Object} items - 配置项 + * @returns {Promise} - 更新结果 + */ + updateConfigItems: async (items) => { + checkAdminPermission() + console.log("updateConfigItems", items) + return apiPost('/api/config/update', items, {}, true) + }, + + /** + * 重启服务 + * @returns {Promise} - 重启结果 + */ + restartServer: async () => { + checkSuperAdminPermission() + return apiPost('/api/restart', {}, {}, true) + } +} + +// 日志API +export const logApi = { + /** + * 获取系统日志 + * @param {Object} params - 日志查询参数 + * @returns {Promise} - 日志数据 + */ + getLogs: async (params = {}) => { + checkAdminPermission() + return apiGet('/api/log', { params }, true) + }, +} \ No newline at end of file diff --git a/web/src/apis/auth_api.js b/web/src/apis/auth_api.js new file mode 100644 index 000000000..1e2f34b7b --- /dev/null +++ b/web/src/apis/auth_api.js @@ -0,0 +1,143 @@ +import { apiGet, apiPost, apiDelete } from './base' +import { useUserStore } from '@/stores/user' + +/** + * 需要用户认证的API模块 + * 用户必须登录才能访问的API + * 权限要求: 任何已登录用户(普通用户、管理员、超级管理员) + */ + +// 聊天相关API +export const chatApi = { + /** + * 发送聊天消息 + * @param {Object} params - 聊天参数 + * @returns {Promise} - 聊天响应流 + */ + sendMessage: (params) => { + return fetch('/api/chat/', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + ...useUserStore().getAuthHeaders() + }, + body: JSON.stringify(params), + }) + }, + + /** + * 发送可中断的聊天消息 + * @param {Object} params - 聊天参数 + * @param {AbortSignal} signal - 用于中断请求的信号控制器 + * @returns {Promise} - 聊天响应流 + */ + sendMessageWithAbort: (params, signal) => { + return fetch('/api/chat/', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + ...useUserStore().getAuthHeaders() + }, + body: JSON.stringify(params), + signal // 添加 signal 用于中断请求 + }) + }, + + /** + * 发送聊天消息到指定智能体(流式响应) + * @param {string} agentId - 智能体ID + * @param {Object} data - 聊天数据 + * @returns {Promise} - 聊天响应流 + */ + sendAgentMessage: (agentId, data) => { + return fetch(`/api/chat/agent/${agentId}`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + ...useUserStore().getAuthHeaders() + }, + body: JSON.stringify(data) + }) + }, + + /** + * 简单聊天调用(非流式) + * @param {string} query - 查询内容 + * @returns {Promise} - 聊天响应 + */ + simpleCall: (query) => apiPost('/api/chat/call', { query }, {}, true), + + /** + * 获取默认智能体 + * @returns {Promise} - 默认智能体信息 + */ + getDefaultAgent: () => apiGet('/api/chat/default_agent', {}, true), + + /** + * 获取智能体列表 + * @returns {Promise} - 智能体列表 + */ + getAgents: () => apiGet('/api/chat/agent', {}, true), + + /** + * 获取单个智能体详情 + * @param {string} agentId - 智能体ID + * @returns {Promise} - 智能体详情 + */ + getAgentDetail: (agentId) => apiGet(`/api/chat/agent/${agentId}`, {}, true), + + /** + * 获取可用工具列表 + * @returns {Promise} - 工具列表 + */ + getTools: () => apiGet('/api/chat/tools', {}, true), + + // /** + // * 获取对话历史 + // * @returns {Promise} - 对话历史列表 + // */ + // getConversations: () => apiGet('/api/chat/conversations', {}, true), + + // /** + // * 获取特定对话 + // * @param {string} conversationId - 对话ID + // * @returns {Promise} - 对话详情 + // */ + // getConversation: (conversationId) => + // apiGet(`/api/chat/conversations/${conversationId}`, {}, true), + + // /** + // * 删除对话 + // * @param {string} conversationId - 对话ID + // * @returns {Promise} - 删除结果 + // */ + // deleteConversation: (conversationId) => + // apiDelete(`/api/chat/conversations/${conversationId}`, {}, true), + + // /** + // * 更新对话标题 + // * @param {string} conversationId - 对话ID + // * @param {string} title - 新标题 + // * @returns {Promise} - 更新结果 + // */ + // updateConversationTitle: (conversationId, title) => + // apiPost(`/api/chat/conversations/${conversationId}/title`, { title }, {}, true), +} + +// 用户设置API +export const userSettingsApi = { + /** + * 获取用户设置 + * @returns {Promise} - 用户设置 + */ + getSettings: () => apiGet('/api/user/settings', {}, true), + + /** + * 更新用户设置 + * @param {Object} settings - 新设置 + * @returns {Promise} - 更新结果 + */ + updateSettings: (settings) => apiPost('/api/user/settings', settings, {}, true), +} + +// 其他需要用户认证的API可以继续添加到这里 \ No newline at end of file diff --git a/web/src/apis/base.js b/web/src/apis/base.js new file mode 100644 index 000000000..9e6a7eee0 --- /dev/null +++ b/web/src/apis/base.js @@ -0,0 +1,154 @@ +import { useUserStore } from '@/stores/user' +import { message } from 'ant-design-vue' + +/** + * 基础API请求封装 + * 提供统一的请求方法,自动处理认证头和错误 + */ + +/** + * 发送API请求的基础函数 + * @param {string} url - API端点 + * @param {Object} options - 请求选项 + * @param {boolean} requiresAuth - 是否需要认证头 + * @returns {Promise} - 请求结果 + */ +export async function apiRequest(url, options = {}, requiresAuth = false) { + try { + // 默认请求配置 + const requestOptions = { + ...options, + headers: { + 'Content-Type': 'application/json', + ...options.headers, + }, + } + + // 如果需要认证,添加认证头 + if (requiresAuth) { + const userStore = useUserStore() + if (!userStore.isLoggedIn) { + throw new Error('用户未登录') + } + + Object.assign(requestOptions.headers, userStore.getAuthHeaders()) + } + + // 发送请求 + const response = await fetch(url, requestOptions) + + // 处理API返回的错误 + if (!response.ok) { + // 尝试解析错误信息 + let errorMessage = `请求失败: ${response.status}` + let errorData = null + + try { + errorData = await response.json() + errorMessage = errorData.detail || errorData.message || errorMessage + } catch (e) { + // 如果无法解析JSON,使用默认错误信息 + } + + // 特殊处理401和403错误 + if (response.status === 401) { + // 如果是认证失败,可能需要重新登录 + const userStore = useUserStore() + if (userStore.isLoggedIn) { + // 如果用户认为自己已登录,但收到401,则可能是令牌过期 + const isTokenExpired = errorData && + (errorData.detail?.includes('令牌已过期') || + errorData.detail?.includes('token expired') || + errorMessage?.includes('令牌已过期') || + errorMessage?.includes('token expired')) + + message.error(isTokenExpired ? '登录已过期,请重新登录' : '认证失败,请重新登录') + userStore.logout() + + // 使用setTimeout确保消息显示后再跳转 + setTimeout(() => { + window.location.href = '/login' + }, 1500) + } + throw new Error('未授权,请先登录') + } else if (response.status === 403) { + throw new Error('没有权限执行此操作') + } + + throw new Error(errorMessage) + } + + // 检查Content-Type以确定如何处理响应 + const contentType = response.headers.get('Content-Type') + if (contentType && contentType.includes('application/json')) { + return await response.json() + } + + return await response.text() + } catch (error) { + console.error('API请求错误:', error) + throw error + } +} + +/** + * 发送GET请求 + * @param {string} url - API端点 + * @param {Object} options - 请求选项 + * @param {boolean} requiresAuth - 是否需要认证 + * @returns {Promise} - 请求结果 + */ +export function apiGet(url, options = {}, requiresAuth = false) { + return apiRequest(url, { method: 'GET', ...options }, requiresAuth) +} + +/** + * 发送POST请求 + * @param {string} url - API端点 + * @param {Object} data - 请求体数据 + * @param {Object} options - 其他请求选项 + * @param {boolean} requiresAuth - 是否需要认证 + * @returns {Promise} - 请求结果 + */ +export function apiPost(url, data = {}, options = {}, requiresAuth = false) { + return apiRequest( + url, + { + method: 'POST', + body: JSON.stringify(data), + ...options + }, + requiresAuth + ) +} + +/** + * 发送PUT请求 + * @param {string} url - API端点 + * @param {Object} data - 请求体数据 + * @param {Object} options - 其他请求选项 + * @param {boolean} requiresAuth - 是否需要认证 + * @returns {Promise} - 请求结果 + */ +export function apiPut(url, data = {}, options = {}, requiresAuth = false) { + return apiRequest( + url, + { + method: 'PUT', + body: JSON.stringify(data), + ...options + }, + requiresAuth + ) +} + +/** + * 发送DELETE请求 + * @param {string} url - API端点 + * @param {Object} options - 请求选项 + * @param {boolean} requiresAuth - 是否需要认证 + * @returns {Promise} - 请求结果 + */ +export function apiDelete(url, options = {}, requiresAuth = false) { + return apiRequest(url, { method: 'DELETE', ...options }, requiresAuth) +} \ No newline at end of file diff --git a/web/src/apis/index.js b/web/src/apis/index.js new file mode 100644 index 000000000..6acfb3fc0 --- /dev/null +++ b/web/src/apis/index.js @@ -0,0 +1,33 @@ +/** + * API模块索引文件 + * 导出所有API模块,方便统一引入 + */ + +// 导出公共API模块 +export * from './public_api' + +// 导出需要用户认证的API模块 +export * from './auth_api' + +// 导出需要管理员权限的API模块 +export * from './admin_api' + +// 导出基础工具函数 +export { apiRequest, apiGet, apiPost, apiPut, apiDelete } from './base' + +/** + * 权限说明: + * + * 1. public_api.js: 不需要认证就可以访问的API + * - 登录、初始化管理员、获取公共配置等 + * + * 2. auth_api.js: 需要用户认证才能访问的API + * - 权限要求: 任何已登录用户(普通用户、管理员、超级管理员) + * - 聊天功能、个人设置等 + * + * 3. admin_api.js: 需要管理员权限才能访问的API + * - 权限要求: admin 或 superadmin + * - 用户管理、知识库管理、系统配置等 + * + * 注意:本模块已处理权限验证和请求头,使用时无需再手动添加认证头 + */ \ No newline at end of file diff --git a/web/src/apis/public_api.js b/web/src/apis/public_api.js new file mode 100644 index 000000000..66cb04122 --- /dev/null +++ b/web/src/apis/public_api.js @@ -0,0 +1,59 @@ +import { apiGet, apiPost } from './base' + +/** + * 公共API模块 + * 包含所有不需要认证的公共接口 + */ + +// 登录相关API +export const authApi = { + /** + * 用户登录 + * @param {Object} credentials - 登录凭证 + * @returns {Promise} - 登录结果 + */ + login: (credentials) => { + const formData = new FormData() + formData.append('username', credentials.username) + formData.append('password', credentials.password) + + return apiRequest('/api/auth/token', { + method: 'POST', + body: formData + }, false) + }, + + /** + * 检查是否是首次运行 + * @returns {Promise} - 是否首次运行 + */ + checkFirstRun: () => apiGet('/api/auth/check-first-run'), + + /** + * 初始化管理员账户 + * @param {Object} adminData - 管理员账户数据 + * @returns {Promise} - 初始化结果 + */ + initializeAdmin: (adminData) => apiPost('/api/auth/initialize', adminData), +} + +// 配置相关API +export const configApi = { + /** + * 获取系统配置 + * @returns {Promise} - 系统配置 + */ + getConfig: () => apiGet('/api/config'), +} + +// 健康检查API +export const healthApi = { + /** + * 系统健康检查 + * @returns {Promise} - 健康检查结果 + */ + check: () => apiGet('/api/health'), +} + +// 从base.js导入apiRequest以支持FormData +import { apiRequest } from './base' \ No newline at end of file diff --git a/web/src/components/AgentChatComponent.vue b/web/src/components/AgentChatComponent.vue index 5272d74be..33615a87d 100644 --- a/web/src/components/AgentChatComponent.vue +++ b/web/src/components/AgentChatComponent.vue @@ -12,11 +12,11 @@
-
@@ -110,6 +110,7 @@ import { import { message } from 'ant-design-vue'; import MessageInputComponent from '@/components/MessageInputComponent.vue' import MessageComponent from '@/components/MessageComponent.vue' +import { chatApi } from '@/apis/auth_api' // 新增props属性,允许父组件传入agentId const props = defineProps({ @@ -354,11 +355,7 @@ const sendMessageWithText = async (text) => { }; // 发送请求 - const response = await fetch(`/api/chat/agent/${currentAgent.value.name}`, { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify(requestData) - }); + const response = await chatApi.sendAgentMessage(currentAgent.value.name, requestData); // console.log("requestData", requestData); if (!response.ok) { @@ -752,18 +749,13 @@ const appendToolMessageToExistingAssistant = async (data) => { // 获取智能体列表 const fetchAgents = async () => { try { - const response = await fetch('/api/chat/agent'); - if (response.ok) { - const data = await response.json(); - // 将数组转换为对象 - agents.value = data.agents.reduce((acc, agent) => { - acc[agent.name] = agent; - return acc; - }, {}); - console.log("agents", agents.value); - } else { - console.error('获取智能体失败'); - } + const data = await chatApi.getAgents(); + // 将数组转换为对象 + agents.value = data.agents.reduce((acc, agent) => { + acc[agent.name] = agent; + return acc; + }, {}); + console.log("agents", agents.value); } catch (error) { console.error('获取智能体错误:', error); } diff --git a/web/src/components/ChatComponent.vue b/web/src/components/ChatComponent.vue index 6ef374008..546ce5dd6 100644 --- a/web/src/components/ChatComponent.vue +++ b/web/src/components/ChatComponent.vue @@ -184,10 +184,13 @@ import { } from '@ant-design/icons-vue' import { onClickOutside } from '@vueuse/core' import { useConfigStore } from '@/stores/config' +import { useUserStore } from '@/stores/user' import { message } from 'ant-design-vue' import MessageInputComponent from '@/components/MessageInputComponent.vue' import MessageComponent from '@/components/MessageComponent.vue' import RefsSidebar from '@/components/RefsSidebar.vue' +import { chatApi } from '@/apis/auth_api' +import { knowledgeBaseApi } from '@/apis/admin_api' const props = defineProps({ conv: Object, @@ -196,6 +199,7 @@ const props = defineProps({ const emit = defineEmits(['rename-title', 'newconv']); const configStore = useConfigStore() +const userStore = useUserStore() const { conv, state } = toRefs(props) const chatContainer = ref(null) @@ -475,50 +479,69 @@ const groupRefs = (id) => { scrollToBottom() } +const loadDatabases = () => { + // 由于这是管理功能,需要检查用户是否有管理权限 + if (!userStore.isAdmin) { + console.log('非管理员用户,跳过加载数据库列表'); + return; + } + + try { + knowledgeBaseApi.getDatabases() + .then(data => { + console.log(data) + opts.databases = data.databases + }) + .catch(error => { + console.error('加载数据库列表失败:', error) + }) + } catch (error) { + console.error('获取数据库列表失败:', error); + } +} + const simpleCall = (msg) => { return new Promise((resolve, reject) => { - fetch('/api/chat/call', { - method: 'POST', - body: JSON.stringify({ query: msg, }), - headers: { 'Content-Type': 'application/json' } - }) - .then((response) => response.json()) - .then((data) => resolve(data)) - .catch((error) => reject(error)) + chatApi.simpleCall(msg) + .then(data => resolve(data)) + .catch(error => reject(error)) }) } -const loadDatabases = () => { - fetch('/api/data/', { method: "GET", }) - .then(response => response.json()) - .then(data => { - console.log(data) - opts.databases = data.databases - }) -} - -// 新函数用于处理 fetch 请求 +// 替换fetchChatResponse函数 const fetchChatResponse = (user_input, cur_res_id) => { const controller = new AbortController(); const signal = controller.signal; const params = { query: user_input, - history: getHistory().slice(0, -1), // 去掉最后一条刚添加的用户消息, + history: getHistory().slice(0, -1), // 去掉最后一条刚添加的用户消息 meta: meta, cur_res_id: cur_res_id, } console.log(params) - fetch('/api/chat/', { - method: 'POST', - body: JSON.stringify(params), - headers: { - 'Content-Type': 'application/json' - }, - signal // 添加 signal 用于中断请求 - }) + // 使用API函数发送请求 + chatApi.sendMessageWithAbort(params, signal) .then((response) => { + if (!response.ok) { + // 检查是否是401错误(令牌过期) + if (response.status === 401) { + const userStore = useUserStore(); + if (userStore.isLoggedIn) { + message.error('登录已过期,请重新登录'); + userStore.logout(); + + // 使用setTimeout确保消息显示后再跳转 + setTimeout(() => { + window.location.href = '/login'; + }, 1500); + } + throw new Error('未授权,请先登录'); + } + throw new Error(`请求失败: ${response.status} ${response.statusText}`); + } + if (!response.body) throw new Error("ReadableStream not supported."); const reader = response.body.getReader(); const decoder = new TextDecoder("utf-8"); @@ -559,8 +582,6 @@ const fetchChatResponse = (user_input, cur_res_id) => { meta: data.meta, ...data, }); - // console.log("Last message", conv.value.messages[conv.value.messages.length - 1].content) - // console.log("Last message", conv.value.messages[conv.value.messages.length - 1].status) if (data.history) { conv.value.history = data.history; @@ -583,11 +604,18 @@ const fetchChatResponse = (user_input, cur_res_id) => { if (error.name === 'AbortError') { console.log('Fetch aborted'); } else { - console.error(error); - updateMessage({ - id: cur_res_id, - status: "error", - }); + console.error('聊天请求错误:', error); + + // 检查是否是认证错误 + if (error.message.includes('未授权') || error.message.includes('令牌已过期')) { + // 已在上面处理,这里不需要重复处理 + } else { + updateMessage({ + id: cur_res_id, + status: "error", + message: error.message || '请求失败', + }); + } } isStreaming.value = false; }); diff --git a/web/src/components/DebugComponent.vue b/web/src/components/DebugComponent.vue index 07b0d9eee..d7bb3640f 100644 --- a/web/src/components/DebugComponent.vue +++ b/web/src/components/DebugComponent.vue @@ -22,7 +22,7 @@ {{ state.isFullscreen ? '退出全屏' : '全屏' }} - @@ -72,9 +72,11 @@ diff --git a/web/src/components/TokenManagerComponent.vue b/web/src/components/TokenManagerComponent.vue deleted file mode 100644 index 4b074f859..000000000 --- a/web/src/components/TokenManagerComponent.vue +++ /dev/null @@ -1,258 +0,0 @@ - - - - - \ No newline at end of file diff --git a/web/src/components/UserInfoComponent.vue b/web/src/components/UserInfoComponent.vue new file mode 100644 index 000000000..e22e36e9b --- /dev/null +++ b/web/src/components/UserInfoComponent.vue @@ -0,0 +1,162 @@ + + + + + \ No newline at end of file diff --git a/web/src/components/tools/ConvertToTxtComponent.vue b/web/src/components/tools/ConvertToTxtComponent.vue deleted file mode 100644 index a95ce38cf..000000000 --- a/web/src/components/tools/ConvertToTxtComponent.vue +++ /dev/null @@ -1,183 +0,0 @@ - - - - - diff --git a/web/src/components/tools/TextChunkingComponent.vue b/web/src/components/tools/TextChunkingComponent.vue deleted file mode 100644 index d06c5dfd6..000000000 --- a/web/src/components/tools/TextChunkingComponent.vue +++ /dev/null @@ -1,276 +0,0 @@ - - - - - \ No newline at end of file diff --git a/web/src/layouts/AppLayout.vue b/web/src/layouts/AppLayout.vue index e29115ce1..6e325463f 100644 --- a/web/src/layouts/AppLayout.vue +++ b/web/src/layouts/AppLayout.vue @@ -29,6 +29,7 @@ import { themeConfig } from '@/assets/theme' import { useConfigStore } from '@/stores/config' import { useDatabaseStore } from '@/stores/database' import DebugComponent from '@/components/DebugComponent.vue' +import UserInfoComponent from '@/components/UserInfoComponent.vue' const configStore = useConfigStore() const databaseStore = useDatabaseStore() @@ -57,11 +58,12 @@ const getRemoteDatabase = () => { const fetchGithubStars = async () => { try { isLoadingStars.value = true + // 公共API,可以直接使用fetch const response = await fetch('https://api.github.com/repos/xerrors/Yuxi-Know') const data = await response.json() githubStars.value = data.stargazers_count } catch (error) { - console.error('Error fetching GitHub stars:', error) + console.error('获取GitHub stars失败:', error) } finally { isLoadingStars.value = false } @@ -77,18 +79,17 @@ onMounted(() => { const route = useRoute() console.log(route) -const apiDocsUrl = computed(() => { - // return `${import.meta.env.VITE_API_URL || `http://${window.location.hostname}:${window.location.port}`}/docs` - return `http://localhost:5050/docs` -}) - - // 下面是导航菜单部分,添加智能体项 const mainList = [{ name: '对话', path: '/chat', icon: MessageOutlined, activeIcon: MessageFilled, + }, { + name: '智能体', + path: '/agent', + icon: ToolOutlined, + activeIcon: ToolFilled, }, { name: '图谱', path: '/graph', @@ -101,11 +102,6 @@ const mainList = [{ icon: BookOutlined, activeIcon: BookFilled, // hidden: !configStore.config.enable_knowledge_base, - }, { - name: '工具', - path: '/tools', - icon: ToolOutlined, - activeIcon: ToolFilled, } ] @@ -163,25 +159,36 @@ const mainList = [{
+ + - - - - - @@ -213,25 +209,30 @@ import { CloseOutlined, SettingOutlined, KeyOutlined, - LinkOutlined + LinkOutlined, + StarOutlined, + StarFilled } from '@ant-design/icons-vue'; import { message } from 'ant-design-vue'; import AgentChatComponent from '@/components/AgentChatComponent.vue'; -import TokenManagerComponent from '@/components/TokenManagerComponent.vue'; +import { useUserStore } from '@/stores/user'; +import { chatApi } from '@/apis/auth_api'; +import { systemConfigApi } from '@/apis/admin_api'; // 路由 const router = useRouter(); +const userStore = useUserStore(); // 状态 const agents = ref({}); const selectedAgentId = ref(null); const availableTools = ref([]); // 存储所有可用的工具列表 +const defaultAgentId = ref(null); // 存储默认智能体ID const state = reactive({ debug_mode: false, isSidebarOpen: JSON.parse(localStorage.getItem('agent-sidebar-open') || 'true'), isConfigSidebarOpen: false, configModalVisible: false, - tokenModalVisible: false, isEmptyConfig: computed(() => !selectedAgentId.value || Object.keys(configurableItems.value).length === 0 @@ -243,50 +244,50 @@ const configurableItems = computed(() => configSchema.value.configurable_items | // 配置状态 const agentConfig = ref({}); -// 调试模式 -const toggleDebugMode = () => { - state.debug_mode = !state.debug_mode; -}; - -// 打开配置弹窗 -const openConfigModal = () => { - state.configModalVisible = true; -}; +// 检查是否为默认智能体 +const isDefaultAgent = computed(() => { + return selectedAgentId.value === defaultAgentId.value; +}); -// 关闭配置弹窗 -const closeConfigModal = () => { - state.configModalVisible = false; -}; +// 设置为默认智能体 +const setAsDefaultAgent = async () => { + if (!selectedAgentId.value || !userStore.isAdmin) return; -// 打开令牌管理弹窗 -const openTokenModal = () => { - state.tokenModalVisible = true; + try { + await systemConfigApi.setDefaultAgent(selectedAgentId.value); + defaultAgentId.value = selectedAgentId.value; + message.success('已将当前智能体设为默认'); + } catch (error) { + console.error('设置默认智能体错误:', error); + message.error(error.message || '设置默认智能体时发生错误'); + } }; -// 关闭令牌管理弹窗 -const closeTokenModal = () => { - state.tokenModalVisible = false; +// 获取默认智能体ID +const fetchDefaultAgent = async () => { + try { + const data = await chatApi.getDefaultAgent(); + defaultAgentId.value = data.default_agent_id; + console.log("Default agent ID:", defaultAgentId.value); + } catch (error) { + console.error('获取默认智能体错误:', error); + } }; // 获取智能体列表 const fetchAgents = async () => { try { - const response = await fetch('/api/chat/agent'); - if (response.ok) { - const data = await response.json(); - // 将数组转换为对象 - agents.value = data.agents.reduce((acc, agent) => { - acc[agent.name] = agent; - return acc; - }, {}); - // console.log("agents", agents.value); - - // 加载当前选中智能体的配置 - if (selectedAgentId.value) { - loadAgentConfig(); - } - } else { - console.error('获取智能体失败'); + const data = await chatApi.getAgents(); + // 将数组转换为对象 + agents.value = data.agents.reduce((acc, agent) => { + acc[agent.name] = agent; + return acc; + }, {}); + // console.log("agents", agents.value); + + // 加载当前选中智能体的配置 + if (selectedAgentId.value) { + loadAgentConfig(); } } catch (error) { console.error('获取智能体错误:', error); @@ -296,14 +297,9 @@ const fetchAgents = async () => { // 获取所有可用工具 const fetchTools = async () => { try { - const response = await fetch('/api/chat/tools'); - if (response.ok) { - const data = await response.json(); - availableTools.value = data.tools; - console.log("Available tools:", availableTools.value); - } else { - console.error('获取工具列表失败'); - } + const data = await chatApi.getTools(); + availableTools.value = data.tools; + console.log("Available tools:", availableTools.value); } catch (error) { console.error('获取工具列表错误:', error); } @@ -420,6 +416,8 @@ const selectAgent = (agentId) => { // 初始化 onMounted(async () => { + // 获取默认智能体 + await fetchDefaultAgent(); // 获取智能体列表 await fetchAgents(); // 获取工具列表 @@ -429,6 +427,9 @@ onMounted(async () => { const lastSelectedAgent = localStorage.getItem('last-selected-agent'); if (lastSelectedAgent && agents.value[lastSelectedAgent]) { selectedAgentId.value = lastSelectedAgent; + } else if (defaultAgentId.value && agents.value[defaultAgentId.value]) { + // 如果有默认智能体,优先选择默认智能体 + selectedAgentId.value = defaultAgentId.value; } else if (Object.keys(agents.value).length > 0) { // 默认选择第一个智能体 selectedAgentId.value = Object.keys(agents.value)[0]; @@ -484,6 +485,21 @@ const toggleTool = (tool, checked) => { agentConfig.value.tools = agentConfig.value.tools.filter(item => item !== tool); } }; + +// 调试模式 +const toggleDebugMode = () => { + state.debug_mode = !state.debug_mode; +}; + +// 打开配置弹窗 +const openConfigModal = () => { + state.configModalVisible = true; +}; + +// 关闭配置弹窗 +const closeConfigModal = () => { + state.configModalVisible = false; +}; \ No newline at end of file diff --git a/web/src/views/SettingView.vue b/web/src/views/SettingView.vue index d2fdb82c7..f46eb9dcf 100644 --- a/web/src/views/SettingView.vue +++ b/web/src/views/SettingView.vue @@ -1,9 +1,8 @@