Compare commits
10 Commits
4d622de48d
...
79d2d93af9
Author | SHA1 | Date |
---|---|---|
|
79d2d93af9 | |
|
b1d0fa350a | |
|
39d270fbec | |
|
42eb59bdfe | |
|
72d8723dc8 | |
|
03482158ab | |
|
ec285c03d4 | |
|
98754dfdcc | |
|
577eed1f5e | |
|
1a2cdc4c60 |
|
@ -0,0 +1,587 @@
|
|||
# Sim-Search API Specification
|
||||
|
||||
This document provides a comprehensive guide for frontend developers to integrate with the Sim-Search API. The API offers intelligent research capabilities, including query processing, search execution across multiple engines, and report generation.
|
||||
|
||||
## API Base URL
|
||||
|
||||
```
|
||||
/api/v1
|
||||
```
|
||||
|
||||
## Authentication
|
||||
|
||||
The API uses OAuth2 with Bearer token authentication. All API endpoints except for authentication endpoints require a valid Bearer token.
|
||||
|
||||
### Register a New User
|
||||
|
||||
```
|
||||
POST /api/v1/auth/register
|
||||
```
|
||||
|
||||
Register a new user account.
|
||||
|
||||
**Request Body**:
|
||||
```json
|
||||
{
|
||||
"email": "user@example.com",
|
||||
"password": "password123",
|
||||
"full_name": "User Name",
|
||||
"is_active": true,
|
||||
"is_superuser": false
|
||||
}
|
||||
```
|
||||
|
||||
**Response** (200 OK):
|
||||
```json
|
||||
{
|
||||
"id": "user-uuid",
|
||||
"email": "user@example.com",
|
||||
"full_name": "User Name",
|
||||
"is_active": true,
|
||||
"is_superuser": false
|
||||
}
|
||||
```
|
||||
|
||||
### Login to Get Access Token
|
||||
|
||||
```
|
||||
POST /api/v1/auth/token
|
||||
```
|
||||
|
||||
Obtain an access token for API authentication.
|
||||
|
||||
**Request Body (form data)**:
|
||||
```
|
||||
username=user@example.com
|
||||
password=password123
|
||||
```
|
||||
|
||||
**Response** (200 OK):
|
||||
```json
|
||||
{
|
||||
"access_token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9...",
|
||||
"token_type": "bearer"
|
||||
}
|
||||
```
|
||||
|
||||
## Query Processing
|
||||
|
||||
### Process a Query
|
||||
|
||||
```
|
||||
POST /api/v1/query/process
|
||||
```
|
||||
|
||||
Process a search query to enhance and structure it for better search results.
|
||||
|
||||
**Headers**:
|
||||
- Authorization: Bearer {access_token}
|
||||
|
||||
**Request Body**:
|
||||
```json
|
||||
{
|
||||
"query": "What are the latest advancements in quantum computing?"
|
||||
}
|
||||
```
|
||||
|
||||
**Response** (200 OK):
|
||||
```json
|
||||
{
|
||||
"original_query": "What are the latest advancements in quantum computing?",
|
||||
"structured_query": {
|
||||
"original_query": "What are the latest advancements in quantum computing?",
|
||||
"enhanced_query": "What are the recent breakthroughs and developments in quantum computing technology, algorithms, and applications in the past 2 years?",
|
||||
"type": "exploratory",
|
||||
"intent": "research",
|
||||
"domain": "academic",
|
||||
"confidence": 0.95,
|
||||
"reasoning": "This query is asking about recent developments in a scientific field, which is typical of academic research.",
|
||||
"entities": ["quantum computing", "advancements"],
|
||||
"sub_questions": [
|
||||
{
|
||||
"sub_question": "What are the latest hardware advancements in quantum computing?",
|
||||
"aspect": "hardware",
|
||||
"priority": 0.9
|
||||
},
|
||||
{
|
||||
"sub_question": "What are the recent algorithmic breakthroughs in quantum computing?",
|
||||
"aspect": "algorithms",
|
||||
"priority": 0.8
|
||||
}
|
||||
],
|
||||
"search_queries": {
|
||||
"google": "latest advancements in quantum computing 2024",
|
||||
"scholar": "recent quantum computing breakthroughs",
|
||||
"arxiv": "quantum computing hardware algorithms"
|
||||
},
|
||||
"is_academic": true,
|
||||
"is_code": false,
|
||||
"is_current_events": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Classify a Query
|
||||
|
||||
```
|
||||
POST /api/v1/query/classify
|
||||
```
|
||||
|
||||
Classify a query by type and intent.
|
||||
|
||||
**Headers**:
|
||||
- Authorization: Bearer {access_token}
|
||||
|
||||
**Request Body**:
|
||||
```json
|
||||
{
|
||||
"query": "What are the latest advancements in quantum computing?"
|
||||
}
|
||||
```
|
||||
|
||||
**Response** (200 OK):
|
||||
```json
|
||||
{
|
||||
"original_query": "What are the latest advancements in quantum computing?",
|
||||
"structured_query": {
|
||||
"original_query": "What are the latest advancements in quantum computing?",
|
||||
"type": "exploratory",
|
||||
"domain": "academic",
|
||||
"confidence": 0.95
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Search Execution
|
||||
|
||||
### Get Available Search Engines
|
||||
|
||||
```
|
||||
GET /api/v1/search/engines
|
||||
```
|
||||
|
||||
Get a list of available search engines.
|
||||
|
||||
**Headers**:
|
||||
- Authorization: Bearer {access_token}
|
||||
|
||||
**Response** (200 OK):
|
||||
```json
|
||||
["google", "arxiv", "scholar", "news", "openalex", "core", "github", "stackexchange"]
|
||||
```
|
||||
|
||||
### Execute a Search
|
||||
|
||||
```
|
||||
POST /api/v1/search/execute
|
||||
```
|
||||
|
||||
Execute a search with the given parameters.
|
||||
|
||||
**Headers**:
|
||||
- Authorization: Bearer {access_token}
|
||||
|
||||
**Request Body**:
|
||||
```json
|
||||
{
|
||||
"structured_query": {
|
||||
"original_query": "What are the environmental impacts of electric vehicles?",
|
||||
"enhanced_query": "What are the environmental impacts of electric vehicles?",
|
||||
"type": "factual",
|
||||
"domain": "environmental"
|
||||
},
|
||||
"search_engines": ["google", "arxiv"],
|
||||
"num_results": 5,
|
||||
"timeout": 30
|
||||
}
|
||||
```
|
||||
|
||||
**Response** (200 OK):
|
||||
```json
|
||||
{
|
||||
"search_id": "search-uuid",
|
||||
"query": "What are the environmental impacts of electric vehicles?",
|
||||
"enhanced_query": "What are the environmental impacts of electric vehicles?",
|
||||
"results": {
|
||||
"google": [
|
||||
{
|
||||
"title": "Environmental Impacts of Electric Vehicles",
|
||||
"url": "https://example.com/article1",
|
||||
"snippet": "Electric vehicles have several environmental impacts including...",
|
||||
"source": "google",
|
||||
"score": 0.95
|
||||
}
|
||||
],
|
||||
"arxiv": [
|
||||
{
|
||||
"title": "Lifecycle Analysis of Electric Vehicle Environmental Impact",
|
||||
"url": "http://arxiv.org/abs/paper123",
|
||||
"pdf_url": "http://arxiv.org/pdf/paper123",
|
||||
"snippet": "This paper analyzes the complete lifecycle environmental impact of electric vehicles...",
|
||||
"source": "arxiv",
|
||||
"authors": ["Researcher Name1", "Researcher Name2"],
|
||||
"arxiv_id": "paper123",
|
||||
"categories": ["cs.CY", "eess.SY"],
|
||||
"published_date": "2023-01-15T10:30:00Z",
|
||||
"score": 0.92
|
||||
}
|
||||
]
|
||||
},
|
||||
"total_results": 2,
|
||||
"execution_time": 1.25,
|
||||
"timestamp": "2024-03-20T14:25:30Z"
|
||||
}
|
||||
```
|
||||
|
||||
### Get Search History
|
||||
|
||||
```
|
||||
GET /api/v1/search/history
|
||||
```
|
||||
|
||||
Get the user's search history.
|
||||
|
||||
**Headers**:
|
||||
- Authorization: Bearer {access_token}
|
||||
|
||||
**Query Parameters**:
|
||||
- skip (optional, default: 0): Number of records to skip
|
||||
- limit (optional, default: 100): Maximum number of records to return
|
||||
|
||||
**Response** (200 OK):
|
||||
```json
|
||||
{
|
||||
"searches": [
|
||||
{
|
||||
"id": "search-uuid",
|
||||
"query": "What are the environmental impacts of electric vehicles?",
|
||||
"enhanced_query": "What are the environmental impacts of electric vehicles?",
|
||||
"query_type": "factual",
|
||||
"engines": "google,arxiv",
|
||||
"results_count": 10,
|
||||
"created_at": "2024-03-20T14:25:30Z"
|
||||
}
|
||||
],
|
||||
"total": 1
|
||||
}
|
||||
```
|
||||
|
||||
### Get Search Results
|
||||
|
||||
```
|
||||
GET /api/v1/search/{search_id}
|
||||
```
|
||||
|
||||
Get results for a specific search.
|
||||
|
||||
**Headers**:
|
||||
- Authorization: Bearer {access_token}
|
||||
|
||||
**Path Parameters**:
|
||||
- search_id: ID of the search
|
||||
|
||||
**Response** (200 OK):
|
||||
```json
|
||||
{
|
||||
"search_id": "search-uuid",
|
||||
"query": "What are the environmental impacts of electric vehicles?",
|
||||
"enhanced_query": "What are the environmental impacts of electric vehicles?",
|
||||
"results": {
|
||||
"google": [
|
||||
{
|
||||
"title": "Environmental Impacts of Electric Vehicles",
|
||||
"url": "https://example.com/article1",
|
||||
"snippet": "Electric vehicles have several environmental impacts including...",
|
||||
"source": "google",
|
||||
"score": 0.95
|
||||
}
|
||||
],
|
||||
"arxiv": [
|
||||
{
|
||||
"title": "Lifecycle Analysis of Electric Vehicle Environmental Impact",
|
||||
"url": "http://arxiv.org/abs/paper123",
|
||||
"pdf_url": "http://arxiv.org/pdf/paper123",
|
||||
"snippet": "This paper analyzes the complete lifecycle environmental impact of electric vehicles...",
|
||||
"source": "arxiv",
|
||||
"authors": ["Researcher Name1", "Researcher Name2"],
|
||||
"arxiv_id": "paper123",
|
||||
"categories": ["cs.CY", "eess.SY"],
|
||||
"published_date": "2023-01-15T10:30:00Z",
|
||||
"score": 0.92
|
||||
}
|
||||
]
|
||||
},
|
||||
"total_results": 2,
|
||||
"execution_time": 0.0
|
||||
}
|
||||
```
|
||||
|
||||
### Delete Search
|
||||
|
||||
```
|
||||
DELETE /api/v1/search/{search_id}
|
||||
```
|
||||
|
||||
Delete a search from history.
|
||||
|
||||
**Headers**:
|
||||
- Authorization: Bearer {access_token}
|
||||
|
||||
**Path Parameters**:
|
||||
- search_id: ID of the search to delete
|
||||
|
||||
**Response** (204 No Content)
|
||||
|
||||
## Report Generation
|
||||
|
||||
### Generate a Report
|
||||
|
||||
```
|
||||
POST /api/v1/report/generate
|
||||
```
|
||||
|
||||
Generate a report from search results.
|
||||
|
||||
**Headers**:
|
||||
- Authorization: Bearer {access_token}
|
||||
|
||||
**Request Body**:
|
||||
```json
|
||||
{
|
||||
"search_id": "search-uuid",
|
||||
"query": "What are the environmental impacts of electric vehicles?",
|
||||
"detail_level": "standard",
|
||||
"query_type": "comparative",
|
||||
"model": "llama-3.1-8b-instant",
|
||||
"title": "Environmental Impacts of Electric Vehicles"
|
||||
}
|
||||
```
|
||||
|
||||
**Response** (200 OK):
|
||||
```json
|
||||
{
|
||||
"id": "report-uuid",
|
||||
"user_id": "user-uuid",
|
||||
"search_id": "search-uuid",
|
||||
"title": "Environmental Impacts of Electric Vehicles",
|
||||
"content": "Report generation in progress...",
|
||||
"detail_level": "standard",
|
||||
"query_type": "comparative",
|
||||
"model_used": "llama-3.1-8b-instant",
|
||||
"created_at": "2024-03-20T14:30:00Z",
|
||||
"updated_at": "2024-03-20T14:30:00Z"
|
||||
}
|
||||
```
|
||||
|
||||
### Get Report Generation Progress
|
||||
|
||||
```
|
||||
GET /api/v1/report/{report_id}/progress
|
||||
```
|
||||
|
||||
Get the progress of a report generation.
|
||||
|
||||
**Headers**:
|
||||
- Authorization: Bearer {access_token}
|
||||
|
||||
**Path Parameters**:
|
||||
- report_id: ID of the report
|
||||
|
||||
**Response** (200 OK):
|
||||
```json
|
||||
{
|
||||
"report_id": "report-uuid",
|
||||
"progress": 0.75,
|
||||
"status": "Processing chunk 3/4...",
|
||||
"current_chunk": 3,
|
||||
"total_chunks": 4,
|
||||
"current_report": "The environmental impacts of electric vehicles include..."
|
||||
}
|
||||
```
|
||||
|
||||
### Get Report List
|
||||
|
||||
```
|
||||
GET /api/v1/report/list
|
||||
```
|
||||
|
||||
Get a list of the user's reports.
|
||||
|
||||
**Headers**:
|
||||
- Authorization: Bearer {access_token}
|
||||
|
||||
**Query Parameters**:
|
||||
- skip (optional, default: 0): Number of records to skip
|
||||
- limit (optional, default: 100): Maximum number of records to return
|
||||
|
||||
**Response** (200 OK):
|
||||
```json
|
||||
{
|
||||
"reports": [
|
||||
{
|
||||
"id": "report-uuid",
|
||||
"user_id": "user-uuid",
|
||||
"search_id": "search-uuid",
|
||||
"title": "Environmental Impacts of Electric Vehicles",
|
||||
"content": "# Environmental Impacts of Electric Vehicles\n\n## Introduction\n\nElectric vehicles (EVs) have gained popularity...",
|
||||
"detail_level": "standard",
|
||||
"query_type": "comparative",
|
||||
"model_used": "llama-3.1-8b-instant",
|
||||
"created_at": "2024-03-20T14:30:00Z",
|
||||
"updated_at": "2024-03-20T14:35:00Z"
|
||||
}
|
||||
],
|
||||
"total": 1
|
||||
}
|
||||
```
|
||||
|
||||
### Get Report
|
||||
|
||||
```
|
||||
GET /api/v1/report/{report_id}
|
||||
```
|
||||
|
||||
Get a specific report.
|
||||
|
||||
**Headers**:
|
||||
- Authorization: Bearer {access_token}
|
||||
|
||||
**Path Parameters**:
|
||||
- report_id: ID of the report
|
||||
|
||||
**Response** (200 OK):
|
||||
```json
|
||||
{
|
||||
"id": "report-uuid",
|
||||
"user_id": "user-uuid",
|
||||
"search_id": "search-uuid",
|
||||
"title": "Environmental Impacts of Electric Vehicles",
|
||||
"content": "# Environmental Impacts of Electric Vehicles\n\n## Introduction\n\nElectric vehicles (EVs) have gained popularity...",
|
||||
"detail_level": "standard",
|
||||
"query_type": "comparative",
|
||||
"model_used": "llama-3.1-8b-instant",
|
||||
"created_at": "2024-03-20T14:30:00Z",
|
||||
"updated_at": "2024-03-20T14:35:00Z"
|
||||
}
|
||||
```
|
||||
|
||||
### Download Report
|
||||
|
||||
```
|
||||
GET /api/v1/report/{report_id}/download
|
||||
```
|
||||
|
||||
Download a report in the specified format.
|
||||
|
||||
**Headers**:
|
||||
- Authorization: Bearer {access_token}
|
||||
|
||||
**Path Parameters**:
|
||||
- report_id: ID of the report
|
||||
|
||||
**Query Parameters**:
|
||||
- format (optional, default: "markdown"): Format of the report (markdown, html, pdf)
|
||||
|
||||
**Response** (200 OK):
|
||||
- Content-Type: application/octet-stream
|
||||
- Content-Disposition: attachment; filename="report_{report_id}.{format}"
|
||||
- Binary file content
|
||||
|
||||
### Delete Report
|
||||
|
||||
```
|
||||
DELETE /api/v1/report/{report_id}
|
||||
```
|
||||
|
||||
Delete a report.
|
||||
|
||||
**Headers**:
|
||||
- Authorization: Bearer {access_token}
|
||||
|
||||
**Path Parameters**:
|
||||
- report_id: ID of the report to delete
|
||||
|
||||
**Response** (204 No Content)
|
||||
|
||||
## Error Handling
|
||||
|
||||
The API returns standard HTTP status codes to indicate the success or failure of a request.
|
||||
|
||||
### Common Error Codes
|
||||
|
||||
- 400 Bad Request: The request was invalid or cannot be served
|
||||
- 401 Unauthorized: Authentication is required or has failed
|
||||
- 403 Forbidden: The authenticated user doesn't have the necessary permissions
|
||||
- 404 Not Found: The requested resource was not found
|
||||
- 422 Unprocessable Entity: The request data failed validation
|
||||
- 500 Internal Server Error: An error occurred on the server
|
||||
|
||||
### Error Response Format
|
||||
|
||||
```json
|
||||
{
|
||||
"detail": "Error message explaining what went wrong"
|
||||
}
|
||||
```
|
||||
|
||||
## Best Practices for Frontend Integration
|
||||
|
||||
1. **Authentication Flow**:
|
||||
- Implement a login form that sends credentials to `/api/v1/auth/token`
|
||||
- Store the received token securely (HTTP-only cookies or secure storage)
|
||||
- Include the token in the Authorization header for all subsequent requests
|
||||
- Implement token expiration handling and refresh mechanism
|
||||
|
||||
2. **Query Processing Workflow**:
|
||||
- Allow users to enter natural language queries
|
||||
- Use the `/api/v1/query/process` endpoint to enhance the query
|
||||
- Display the enhanced query to the user for confirmation
|
||||
|
||||
3. **Search Execution**:
|
||||
- Use the processed query for search execution
|
||||
- Allow users to select which search engines to use
|
||||
- Implement a loading state while waiting for search results
|
||||
- Display search results grouped by search engine
|
||||
|
||||
4. **Report Generation**:
|
||||
- Allow users to generate reports from search results
|
||||
- Provide options for detail level and report type
|
||||
- Implement progress tracking using the progress endpoint
|
||||
- Allow users to download reports in different formats
|
||||
|
||||
5. **Error Handling**:
|
||||
- Implement proper error handling for API responses
|
||||
- Display meaningful error messages to users
|
||||
- Implement retry mechanisms for transient errors
|
||||
|
||||
## Available Search Engines
|
||||
|
||||
- **google**: General web search
|
||||
- **arxiv**: Academic papers from arXiv
|
||||
- **scholar**: Academic papers from various sources
|
||||
- **news**: News articles
|
||||
- **openalex**: Open access academic content
|
||||
- **core**: Open access research papers
|
||||
- **github**: Code repositories
|
||||
- **stackexchange**: Q&A from Stack Exchange network
|
||||
|
||||
## Report Detail Levels
|
||||
|
||||
- **brief**: Short summary (default model: llama-3.1-8b-instant)
|
||||
- **standard**: Comprehensive overview (default model: llama-3.1-8b-instant)
|
||||
- **detailed**: In-depth analysis (default model: llama-3.3-70b-versatile)
|
||||
- **comprehensive**: Extensive research report (default model: llama-3.3-70b-versatile)
|
||||
|
||||
## Query Types
|
||||
|
||||
- **factual**: Seeking facts or information
|
||||
- **comparative**: Comparing multiple items or concepts
|
||||
- **exploratory**: Open-ended exploration of a topic
|
||||
- **procedural**: How to do something
|
||||
- **causal**: Seeking cause-effect relationships
|
||||
|
||||
## Models
|
||||
|
||||
- **llama-3.1-8b-instant**: Fast, lightweight model
|
||||
- **llama-3.3-70b-versatile**: High-quality, comprehensive model
|
||||
- Other models may be available based on server configuration
|
|
@ -82,6 +82,48 @@ project/
|
|||
│ └── gradio_interface.py # Gradio-based web interface
|
||||
├── scripts/ # Scripts
|
||||
│ └── query_to_report.py # Script for generating reports from queries
|
||||
├── sim-search-api/ # FastAPI backend
|
||||
│ ├── app/
|
||||
│ │ ├── api/
|
||||
│ │ │ ├── routes/
|
||||
│ │ │ │ ├── __init__.py
|
||||
│ │ │ │ ├── auth.py # Authentication routes
|
||||
│ │ │ │ ├── query.py # Query processing routes
|
||||
│ │ │ │ ├── search.py # Search execution routes
|
||||
│ │ │ │ └── report.py # Report generation routes
|
||||
│ │ │ ├── __init__.py
|
||||
│ │ │ └── dependencies.py # API dependencies (auth, rate limiting)
|
||||
│ │ ├── core/
|
||||
│ │ │ ├── __init__.py
|
||||
│ │ │ ├── config.py # API configuration
|
||||
│ │ │ └── security.py # Security utilities
|
||||
│ │ ├── db/
|
||||
│ │ │ ├── __init__.py
|
||||
│ │ │ ├── session.py # Database session
|
||||
│ │ │ └── models.py # Database models for reports, searches
|
||||
│ │ ├── schemas/
|
||||
│ │ │ ├── __init__.py
|
||||
│ │ │ ├── token.py # Token schemas
|
||||
│ │ │ ├── user.py # User schemas
|
||||
│ │ │ ├── query.py # Query schemas
|
||||
│ │ │ ├── search.py # Search result schemas
|
||||
│ │ │ └── report.py # Report schemas
|
||||
│ │ ├── services/
|
||||
│ │ │ ├── __init__.py
|
||||
│ │ │ ├── query_service.py # Query processing service
|
||||
│ │ │ ├── search_service.py # Search execution service
|
||||
│ │ │ └── report_service.py # Report generation service
|
||||
│ │ └── main.py # FastAPI application
|
||||
│ ├── alembic/ # Database migrations
|
||||
│ │ ├── versions/
|
||||
│ │ │ └── 001_initial_migration.py # Initial migration
|
||||
│ │ ├── env.py # Alembic environment
|
||||
│ │ └── script.py.mako # Alembic script template
|
||||
│ ├── .env.example # Environment variables template
|
||||
│ ├── alembic.ini # Alembic configuration
|
||||
│ ├── requirements.txt # API dependencies
|
||||
│ ├── run.py # Script to run the API
|
||||
│ └── README.md # API documentation
|
||||
├── run_ui.py # Script to run the UI
|
||||
└── requirements.txt # Project dependencies
|
||||
```
|
||||
|
@ -263,8 +305,139 @@ The `progressive_report_synthesis` module provides functionality to synthesize r
|
|||
|
||||
- `get_progressive_report_synthesizer(model_name)`: Factory function to get a singleton instance
|
||||
|
||||
### FastAPI Backend Module
|
||||
|
||||
The `sim-search-api` module provides a RESTful API for the sim-search system, allowing for query processing, search execution, and report generation through HTTP endpoints.
|
||||
|
||||
### Files
|
||||
|
||||
- `app/`: Main application directory
|
||||
- `api/`: API routes and dependencies
|
||||
- `routes/`: API route handlers
|
||||
- `auth.py`: Authentication routes
|
||||
- `query.py`: Query processing routes
|
||||
- `search.py`: Search execution routes
|
||||
- `report.py`: Report generation routes
|
||||
- `dependencies.py`: API dependencies (auth, rate limiting)
|
||||
- `core/`: Core functionality
|
||||
- `config.py`: API configuration
|
||||
- `security.py`: Security utilities
|
||||
- `db/`: Database models and session management
|
||||
- `models.py`: Database models for users, searches, and reports
|
||||
- `session.py`: Database session management
|
||||
- `schemas/`: Pydantic schemas for request/response validation
|
||||
- `token.py`: Token schemas
|
||||
- `user.py`: User schemas
|
||||
- `query.py`: Query schemas
|
||||
- `search.py`: Search result schemas
|
||||
- `report.py`: Report schemas
|
||||
- `services/`: Service layer for business logic
|
||||
- `query_service.py`: Query processing service
|
||||
- `search_service.py`: Search execution service
|
||||
- `report_service.py`: Report generation service
|
||||
- `main.py`: FastAPI application entry point
|
||||
- `alembic/`: Database migrations
|
||||
- `versions/`: Migration versions
|
||||
- `001_initial_migration.py`: Initial migration
|
||||
- `env.py`: Alembic environment
|
||||
- `script.py.mako`: Alembic script template
|
||||
- `alembic.ini`: Alembic configuration
|
||||
- `requirements.txt`: API dependencies
|
||||
- `run.py`: Script to run the API
|
||||
- `.env.example`: Environment variables template
|
||||
- `README.md`: API documentation
|
||||
|
||||
### Classes
|
||||
|
||||
- `app.db.models.User`: User model for authentication
|
||||
- `id` (str): User ID
|
||||
- `email` (str): User email
|
||||
- `hashed_password` (str): Hashed password
|
||||
- `full_name` (str): User's full name
|
||||
- `is_active` (bool): Whether the user is active
|
||||
- `is_superuser` (bool): Whether the user is a superuser
|
||||
|
||||
- `app.db.models.Search`: Search model for storing search results
|
||||
- `id` (str): Search ID
|
||||
- `user_id` (str): User ID
|
||||
- `query` (str): Original query
|
||||
- `enhanced_query` (str): Enhanced query
|
||||
- `query_type` (str): Query type
|
||||
- `engines` (str): Search engines used
|
||||
- `results_count` (int): Number of results
|
||||
- `results` (JSON): Search results
|
||||
- `created_at` (datetime): Creation timestamp
|
||||
|
||||
- `app.db.models.Report`: Report model for storing generated reports
|
||||
- `id` (str): Report ID
|
||||
- `user_id` (str): User ID
|
||||
- `search_id` (str): Search ID
|
||||
- `title` (str): Report title
|
||||
- `content` (str): Report content
|
||||
- `detail_level` (str): Detail level
|
||||
- `query_type` (str): Query type
|
||||
- `model_used` (str): Model used for generation
|
||||
- `created_at` (datetime): Creation timestamp
|
||||
- `updated_at` (datetime): Update timestamp
|
||||
|
||||
- `app.services.QueryService`: Service for query processing
|
||||
- `process_query(query)`: Processes a query
|
||||
- `classify_query(query)`: Classifies a query
|
||||
|
||||
- `app.services.SearchService`: Service for search execution
|
||||
- `execute_search(structured_query, search_engines, num_results, timeout, user_id, db)`: Executes a search
|
||||
- `get_available_search_engines()`: Gets available search engines
|
||||
- `get_search_results(search)`: Gets results for a specific search
|
||||
|
||||
- `app.services.ReportService`: Service for report generation
|
||||
- `generate_report_background(report_id, report_in, search, db, progress_dict)`: Generates a report in the background
|
||||
- `generate_report_file(report, format)`: Generates a report file in the specified format
|
||||
|
||||
## Recent Updates
|
||||
|
||||
### 2025-03-20: FastAPI Backend Implementation
|
||||
|
||||
1. **FastAPI Application Structure**:
|
||||
- Created a new directory `sim-search-api` for the FastAPI application
|
||||
- Set up project structure with API routes, core functionality, database models, schemas, and services
|
||||
- Implemented a layered architecture with API, service, and data layers
|
||||
- Added proper `__init__.py` files to make all directories proper Python packages
|
||||
|
||||
2. **API Routes Implementation**:
|
||||
- Created authentication routes for user registration and token generation
|
||||
- Implemented query processing routes for query enhancement and classification
|
||||
- Added search execution routes for executing searches and managing search history
|
||||
- Created report generation routes for generating and managing reports
|
||||
- Implemented proper error handling and validation for all routes
|
||||
|
||||
3. **Service Layer Implementation**:
|
||||
- Created `QueryService` to bridge between API and existing query processing functionality
|
||||
- Implemented `SearchService` for search execution and result management
|
||||
- Added `ReportService` for report generation and management
|
||||
- Ensured proper integration with existing sim-search functionality
|
||||
- Implemented asynchronous operation for all services
|
||||
|
||||
4. **Database Setup**:
|
||||
- Created SQLAlchemy models for users, searches, and reports
|
||||
- Implemented database session management
|
||||
- Set up Alembic for database migrations
|
||||
- Created initial migration script to create all tables
|
||||
- Added proper relationships between models
|
||||
|
||||
5. **Authentication and Security**:
|
||||
- Implemented JWT-based authentication
|
||||
- Added password hashing and verification
|
||||
- Created token generation and validation
|
||||
- Implemented user registration and login
|
||||
- Added proper authorization for protected routes
|
||||
|
||||
6. **Documentation and Configuration**:
|
||||
- Created comprehensive API documentation
|
||||
- Added OpenAPI documentation endpoints
|
||||
- Implemented environment variable configuration
|
||||
- Created a README with setup and usage instructions
|
||||
- Added example environment variables file
|
||||
|
||||
### 2025-03-12: Progressive Report Generation Implementation
|
||||
|
||||
1. **Progressive Report Synthesis Module**:
|
||||
|
|
|
@ -1,13 +1,55 @@
|
|||
# Current Focus: UI Bug Fixes, Project Directory Reorganization, and Embedding Usage
|
||||
# Current Focus: FastAPI Implementation, API Testing, and Progressive Report Generation
|
||||
|
||||
## Active Work
|
||||
|
||||
### FastAPI Implementation
|
||||
- ✅ Created directory structure for FastAPI application following the implementation plan
|
||||
- ✅ Implemented core FastAPI application with configuration and security
|
||||
- ✅ Created database models for users, searches, and reports
|
||||
- ✅ Implemented API routes for authentication, query processing, search execution, and report generation
|
||||
- ✅ Created service layer to bridge between API and existing sim-search functionality
|
||||
- ✅ Set up database migrations with Alembic
|
||||
- ✅ Added comprehensive documentation for the API
|
||||
- ✅ Created environment variable configuration
|
||||
- ✅ Implemented JWT-based authentication
|
||||
- ✅ Added OpenAPI documentation endpoints
|
||||
|
||||
### API Testing
|
||||
- ✅ Created comprehensive test suite for the API using pytest
|
||||
- ✅ Implemented test fixtures for database initialization and user authentication
|
||||
- ✅ Added tests for authentication, query processing, search execution, and report generation
|
||||
- ✅ Created a test runner script with options for verbosity, coverage reporting, and test selection
|
||||
- ✅ Implemented a manual testing script using curl commands
|
||||
- ✅ Added test documentation with instructions for running tests and troubleshooting
|
||||
- ✅ Set up test database isolation to avoid affecting production data
|
||||
- ✅ Fixed deprecated Pydantic features to ensure tests run correctly
|
||||
- ✅ Replaced dict() with model_dump() in API routes
|
||||
- ✅ Updated orm_mode to from_attributes in schema classes
|
||||
- ✅ Changed schema_extra to json_schema_extra in schema classes
|
||||
|
||||
### LLM-Based Query Domain Classification
|
||||
- ✅ Implemented LLM-based query domain classification to replace keyword-based approach
|
||||
- ✅ Added `classify_query_domain` method to `LLMInterface` class
|
||||
- ✅ Created `_structure_query_with_llm` method in `QueryProcessor` to use LLM classification results
|
||||
- ✅ Added fallback to keyword-based classification for resilience
|
||||
- ✅ Enhanced structured query with domain, confidence, and reasoning fields
|
||||
- ✅ Added comprehensive test script to verify functionality
|
||||
- ✅ Added detailed documentation about the new implementation
|
||||
- ✅ Updated configuration to support the new classification method
|
||||
- ✅ Improved logging for better monitoring of classification results
|
||||
|
||||
### UI Bug Fixes
|
||||
- ✅ Fixed AttributeError in report generation progress callback
|
||||
- ✅ Updated UI progress callback to use direct value assignment instead of update method
|
||||
- ✅ Enhanced progress callback to use Gradio's built-in progress tracking mechanism for better UI updates during async operations
|
||||
- ✅ Consolidated redundant progress indicators in the UI to use only Gradio's built-in progress tracking
|
||||
- ✅ Committed changes with message "Enhanced UI progress callback to use Gradio's built-in progress tracking mechanism for better real-time updates during report generation"
|
||||
- ✅ Fixed model selection issue in report generation to ensure the model selected in the UI is properly used throughout the report generation process
|
||||
- ✅ Fixed model provider selection to correctly use the provider specified in the config.yaml file (e.g., ensuring Gemini models use the Gemini provider)
|
||||
- ✅ Added detailed logging for model and provider selection to aid in debugging
|
||||
- ✅ Implemented comprehensive tests for provider selection stability across multiple initializations, model switches, and configuration changes
|
||||
- ✅ Enhanced provider selection stability tests to include fallback mechanisms, edge cases with invalid providers, and provider selection consistency between singleton and new instances
|
||||
- ✅ Added test for provider selection stability after config reload
|
||||
- ✅ Committed changes with message "Enhanced provider selection stability tests with additional scenarios and edge cases"
|
||||
|
||||
### Project Directory Reorganization
|
||||
- ✅ Reorganized project directory structure for better maintainability
|
||||
|
@ -27,14 +69,37 @@
|
|||
- ✅ Verified that the UI works correctly with the new directory structure
|
||||
- ✅ Confirmed that all imports are working properly with the new structure
|
||||
|
||||
## Repository Cleanup
|
||||
- Reorganized test files into dedicated directories under `tests/`
|
||||
- Created `examples/` directory for sample data
|
||||
- Moved utility scripts to `utils/`
|
||||
- Committed changes with message 'Clean up repository: Remove unused test files and add new test directories'
|
||||
|
||||
## Recent Changes
|
||||
|
||||
### API Testing Fixes
|
||||
- Fixed deprecated Pydantic features to ensure tests run correctly
|
||||
- Replaced dict() with model_dump() in API routes
|
||||
- Updated orm_mode to from_attributes in schema classes
|
||||
- Changed schema_extra to json_schema_extra in schema classes
|
||||
- Made test scripts executable for easier running
|
||||
- Committed changes with message "Fix deprecated Pydantic features: replace dict() with model_dump(), orm_mode with from_attributes, and schema_extra with json_schema_extra"
|
||||
|
||||
### API Testing Implementation
|
||||
- Created comprehensive test suite for the API using pytest
|
||||
- Implemented test fixtures for database initialization and user authentication
|
||||
- Added tests for authentication, query processing, search execution, and report generation
|
||||
- Created a test runner script with options for verbosity, coverage reporting, and test selection
|
||||
- Implemented a manual testing script using curl commands
|
||||
- Added test documentation with instructions for running tests and troubleshooting
|
||||
- Set up test database isolation to avoid affecting production data
|
||||
|
||||
### FastAPI Implementation
|
||||
- Created a new `sim-search-api` directory for the FastAPI application
|
||||
- Implemented a layered architecture with API, service, and data layers
|
||||
- Created database models for users, searches, and reports
|
||||
- Implemented API routes for all functionality
|
||||
- Created service layer to bridge between API and existing sim-search functionality
|
||||
- Set up database migrations with Alembic
|
||||
- Added JWT-based authentication
|
||||
- Created comprehensive documentation for the API
|
||||
- Added environment variable configuration
|
||||
- Implemented OpenAPI documentation endpoints
|
||||
|
||||
### Directory Structure Reorganization
|
||||
- Created a dedicated `utils/` directory for utility scripts
|
||||
- Moved `jina_similarity.py` to `utils/`
|
||||
|
@ -48,12 +113,6 @@
|
|||
- Added a dedicated `scripts/` directory for utility scripts
|
||||
- Moved `query_to_report.py` to `scripts/`
|
||||
|
||||
### Pipeline Verification
|
||||
- Verified that the pipeline functions correctly after reorganization
|
||||
- Confirmed that the `JinaSimilarity` class in `utils/jina_similarity.py` is properly used for embeddings
|
||||
- Tested the reranking functionality with the `JinaReranker` class
|
||||
- Checked that the report generation process works with the new structure
|
||||
|
||||
### Query Type Selection in Gradio UI
|
||||
- ✅ Added a dropdown menu for query type selection in the "Generate Report" tab
|
||||
- ✅ Included options for "auto-detect", "factual", "exploratory", and "comparative"
|
||||
|
@ -68,13 +127,16 @@
|
|||
|
||||
## Next Steps
|
||||
|
||||
1. Run comprehensive tests to ensure all functionality works with the new directory structure
|
||||
2. Update any remaining documentation to reflect the new directory structure
|
||||
3. Consider moving the remaining test files in the root of the `tests/` directory to appropriate subdirectories
|
||||
4. Review import statements throughout the codebase to ensure they follow the new structure
|
||||
5. Add more comprehensive documentation about the directory structure
|
||||
6. Consider creating a development guide for new contributors
|
||||
7. Implement automated tests to verify the directory structure remains consistent
|
||||
1. Continue testing the API to ensure all endpoints work correctly
|
||||
2. Fix any remaining issues found during testing
|
||||
3. Add more specific tests for edge cases and error handling
|
||||
4. Integrate the tests into a CI/CD pipeline
|
||||
5. Create a React frontend to consume the FastAPI backend
|
||||
6. Implement user management in the frontend
|
||||
7. Add search history and report management in the frontend
|
||||
8. Implement real-time progress tracking for report generation in the frontend
|
||||
9. Add visualization components for reports in the frontend
|
||||
10. Consider adding more API endpoints for additional functionality
|
||||
|
||||
### Future Enhancements
|
||||
|
||||
|
@ -116,117 +178,32 @@
|
|||
|
||||
### Current Tasks
|
||||
|
||||
1. **Report Generation Module Implementation (Phase 4)**:
|
||||
1. **API Testing**:
|
||||
- Continue testing the API to ensure all endpoints work correctly
|
||||
- Fix any remaining issues found during testing
|
||||
- Add more specific tests for edge cases and error handling
|
||||
- Integrate the tests into a CI/CD pipeline
|
||||
|
||||
2. **Report Generation Module Implementation (Phase 4)**:
|
||||
- Implementing support for alternative models with larger context windows
|
||||
- Implementing progressive report generation for very large research tasks
|
||||
- Creating visualization components for data mentioned in reports
|
||||
- Adding interactive elements to the generated reports
|
||||
- Implementing report versioning and comparison
|
||||
|
||||
2. **Integration with UI**:
|
||||
3. **Integration with UI**:
|
||||
- ✅ Adding report generation options to the UI
|
||||
- ✅ Implementing progress indicators for document scraping and report generation
|
||||
- ✅ Adding query type selection to the UI
|
||||
- Creating visualization components for generated reports
|
||||
- Adding options to customize report generation parameters
|
||||
|
||||
3. **Performance Optimization**:
|
||||
4. **Performance Optimization**:
|
||||
- Optimizing token usage for more efficient LLM utilization
|
||||
- Implementing caching strategies for document scraping and LLM calls
|
||||
- Parallelizing document scraping and processing
|
||||
- Exploring parallel processing for the map phase of report synthesis
|
||||
|
||||
### Recent Progress
|
||||
|
||||
1. **Report Templates Implementation**:
|
||||
- ✅ Created a dedicated `report_templates.py` module with a comprehensive template system
|
||||
- ✅ Implemented `QueryType` enum for categorizing queries (FACTUAL, EXPLORATORY, COMPARATIVE, CODE)
|
||||
- ✅ Created `DetailLevel` enum for different report detail levels (BRIEF, STANDARD, DETAILED, COMPREHENSIVE)
|
||||
- ✅ Designed a `ReportTemplate` class with validation for required sections
|
||||
- ✅ Implemented a `ReportTemplateManager` to manage and retrieve templates
|
||||
- ✅ Created 16 different templates (4 query types × 4 detail levels)
|
||||
- ✅ Added testing with `test_report_templates.py` and `test_brief_report.py`
|
||||
- ✅ Updated memory bank documentation with template system details
|
||||
|
||||
2. **Testing and Validation of Report Templates**:
|
||||
- ✅ Fixed template retrieval issues in the report synthesis module
|
||||
- ✅ Successfully tested all detail levels (brief, standard, detailed, comprehensive) with factual queries
|
||||
- ✅ Successfully tested all detail levels with exploratory queries
|
||||
- ✅ Successfully tested all detail levels with comparative queries
|
||||
- ✅ Improved error handling in template retrieval with fallback to standard templates
|
||||
- ✅ Added better logging for template retrieval process
|
||||
|
||||
3. **UI Enhancements**:
|
||||
- ✅ Added progress tracking for report generation
|
||||
- ✅ Added query type selection dropdown
|
||||
- ✅ Added documentation for query types and detail levels
|
||||
- ✅ Improved error handling in the UI
|
||||
|
||||
### Next Steps
|
||||
|
||||
1. **Further Refinement of Report Templates**:
|
||||
- Conduct additional testing with real-world queries and document sets
|
||||
- Compare the analytical depth and quality of reports generated with different detail levels
|
||||
- Gather user feedback on the improved reports at different detail levels
|
||||
- Further refine the detail level configurations based on testing and feedback
|
||||
- Integrate the template system with the UI to allow users to select detail levels
|
||||
- Add more specialized templates for specific research domains
|
||||
- Implement template customization options for users
|
||||
|
||||
2. **Progressive Report Generation Implementation**:
|
||||
- ✅ Implemented progressive report generation for comprehensive detail level reports
|
||||
- ✅ Created a hybrid system that uses standard map-reduce for brief/standard/detailed levels and progressive generation for comprehensive level
|
||||
- ✅ Added support for different models with adaptive batch sizing
|
||||
- ✅ Implemented progress tracking and callback mechanism
|
||||
- ✅ Created comprehensive test suite for progressive report generation
|
||||
- ⏳ Add UI controls to monitor and control the progressive generation process
|
||||
|
||||
#### Implementation Details for Progressive Report Generation
|
||||
|
||||
**Phase 1: Core Implementation (Completed)**
|
||||
- ✅ Created a new `ProgressiveReportSynthesizer` class extending from `ReportSynthesizer`
|
||||
- ✅ Implemented chunk prioritization algorithm based on relevance scores
|
||||
- ✅ Developed the iterative refinement process with specialized prompts
|
||||
- ✅ Added state management to track report versions and processed chunks
|
||||
- ✅ Implemented termination conditions (all chunks processed, diminishing returns, user intervention)
|
||||
|
||||
**Phase 2: Model Flexibility (Completed)**
|
||||
- ✅ Modified the implementation to support different models beyond Gemini
|
||||
- ✅ Created model-specific configurations for progressive generation
|
||||
- ✅ Implemented adaptive batch sizing based on model context window
|
||||
- ✅ Added fallback mechanisms for when context windows are exceeded
|
||||
|
||||
**Phase 3: UI Integration (In Progress)**
|
||||
- ✅ Added progress tracking callback mechanism
|
||||
- ⏳ Implement controls to pause, resume, or terminate the process
|
||||
- ⏳ Create a preview mode to see the current report state
|
||||
- ⏳ Add options to compare different versions of the report
|
||||
|
||||
**Phase 4: Testing and Optimization (Completed)**
|
||||
- ✅ Created test script for progressive report generation
|
||||
- ✅ Added comparison functionality between progressive and standard approaches
|
||||
- ✅ Implemented optimization for token usage and processing efficiency
|
||||
- ✅ Fine-tuned prompts and parameters based on testing results
|
||||
|
||||
3. **Query Type Selection Enhancement**:
|
||||
- ✅ Added query type selection dropdown to the UI
|
||||
- ✅ Implemented handling of user-selected query types in the report generation process
|
||||
- ✅ Added documentation to help users understand when to use each query type
|
||||
- ✅ Added CODE as a new query type with specialized templates at all detail levels
|
||||
- ✅ Implemented code query detection with language, framework, and pattern recognition
|
||||
- ✅ Added GitHub and StackExchange search handlers for code-related queries
|
||||
- ⏳ Test the query type selection with various queries to ensure it works correctly
|
||||
- ⏳ Gather user feedback on the usefulness of manual query type selection
|
||||
- ⏳ Consider adding more specialized templates for specific query types
|
||||
- ⏳ Explore adding query type detection confidence scores to help users decide when to override
|
||||
- ⏳ Add examples of each query type to help users understand the differences
|
||||
|
||||
4. **Visualization Components**:
|
||||
- Identify common data types in reports that would benefit from visualization
|
||||
- Design and implement visualization components for these data types
|
||||
- Integrate visualization components into the report generation process
|
||||
- Consider how visualizations can be incorporated into progressive reports
|
||||
|
||||
### Technical Notes
|
||||
|
||||
- Using Groq's Llama 3.3 70B Versatile model for detailed and comprehensive report synthesis
|
||||
|
@ -253,3 +230,18 @@
|
|||
- Created code detection based on programming languages, frameworks, and patterns
|
||||
- Designed specialized report templates for code content with syntax highlighting
|
||||
- Enhanced result ranking to prioritize code-related sources for programming queries
|
||||
- Implemented FastAPI backend for the sim-search system:
|
||||
- Created a layered architecture with API, service, and data layers
|
||||
- Implemented JWT-based authentication
|
||||
- Created database models for users, searches, and reports
|
||||
- Added service layer to bridge between API and existing sim-search functionality
|
||||
- Set up database migrations with Alembic
|
||||
- Added comprehensive documentation for the API
|
||||
- Implemented OpenAPI documentation endpoints
|
||||
- Created comprehensive testing framework for the API:
|
||||
- Implemented automated tests with pytest for all API endpoints
|
||||
- Created a test runner script with options for verbosity and coverage reporting
|
||||
- Implemented a manual testing script using curl commands
|
||||
- Added test documentation with instructions for running tests and troubleshooting
|
||||
- Set up test database isolation to avoid affecting production data
|
||||
- Fixed deprecated Pydantic features to ensure tests run correctly
|
||||
|
|
|
@ -439,3 +439,59 @@ Implemented and tested successfully with both sample data and real URLs.
|
|||
- Added duplicate URL fields in the context to ensure URLs are captured
|
||||
- Updated the reference generation prompt to explicitly request URLs
|
||||
- Added a separate reference generation step to handle truncated references
|
||||
|
||||
## 2025-03-18: LLM-Based Query Classification Implementation
|
||||
|
||||
### Context
|
||||
The project was using a keyword-based approach to classify queries into different domains (academic, code, current events). This approach had several limitations:
|
||||
- Reliance on static keyword lists that needed constant maintenance
|
||||
- Inability to understand the semantic meaning of queries
|
||||
- False classifications for ambiguous queries or those containing keywords with multiple meanings
|
||||
- Difficulty handling emerging topics without updating keyword lists
|
||||
|
||||
### Decision
|
||||
1. Replace the keyword-based query classification with an LLM-based approach:
|
||||
- Implement a new `classify_query_domain` method in the `LLMInterface` class
|
||||
- Create a new query structuring method that uses the LLM classification results
|
||||
- Retain the keyword-based method as a fallback
|
||||
- Add confidence scores and reasoning to the classification results
|
||||
|
||||
2. Enhance the structured query format:
|
||||
- Add primary domain and confidence
|
||||
- Include secondary domains with confidence scores
|
||||
- Add classification reasoning
|
||||
- Maintain backward compatibility with existing search executor
|
||||
|
||||
3. Use a 0.3 confidence threshold for secondary domains:
|
||||
- Set domain flags (is_academic, is_code, is_current_events) based on primary domain
|
||||
- Also set flags for secondary domains with confidence scores above 0.3
|
||||
|
||||
### Rationale
|
||||
- LLM-based approach provides better semantic understanding of queries
|
||||
- Multi-domain classification with confidence scores handles complex queries better
|
||||
- Self-explaining classifications with reasoning aids debugging and transparency
|
||||
- The approach automatically adapts to new topics without code changes
|
||||
- Retaining keyword-based fallback ensures system resilience
|
||||
|
||||
### Alternatives Considered
|
||||
1. Expanding the keyword lists:
|
||||
- Would still lack semantic understanding
|
||||
- Would increase the maintenance burden
|
||||
- False positives would still occur
|
||||
|
||||
2. Using embedding similarity to predefined domain descriptions:
|
||||
- Potentially more computationally expensive
|
||||
- Less explainable than the LLM's reasoning
|
||||
- Would require managing embedding models
|
||||
|
||||
3. Creating a custom classifier:
|
||||
- Would require labeled training data
|
||||
- More development effort
|
||||
- Less flexible than the LLM approach
|
||||
|
||||
### Impact
|
||||
- More accurate query classification, especially for ambiguous or multi-domain queries
|
||||
- Reduction in maintenance overhead for keyword lists
|
||||
- Better search engine selection based on query domains
|
||||
- Improved report generation due to more accurate query understanding
|
||||
- Enhanced debugging capabilities with classification reasoning
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
|
||||
## 2025-03-18: LLM-Based Query Classification Implementation
|
||||
|
||||
### Context
|
||||
The project was using a keyword-based approach to classify queries into different domains (academic, code, current events). This approach had several limitations:
|
||||
- Reliance on static keyword lists that needed constant maintenance
|
||||
- Inability to understand the semantic meaning of queries
|
||||
- False classifications for ambiguous queries or those containing keywords with multiple meanings
|
||||
- Difficulty handling emerging topics without updating keyword lists
|
||||
|
||||
### Decision
|
||||
1. Replace the keyword-based query classification with an LLM-based approach:
|
||||
- Implement a new `classify_query_domain` method in the `LLMInterface` class
|
||||
- Create a new query structuring method that uses the LLM classification results
|
||||
- Retain the keyword-based method as a fallback
|
||||
- Add confidence scores and reasoning to the classification results
|
||||
|
||||
2. Enhance the structured query format:
|
||||
- Add primary domain and confidence
|
||||
- Include secondary domains with confidence scores
|
||||
- Add classification reasoning
|
||||
- Maintain backward compatibility with existing search executor
|
||||
|
||||
3. Use a 0.3 confidence threshold for secondary domains:
|
||||
- Set domain flags (is_academic, is_code, is_current_events) based on primary domain
|
||||
- Also set flags for secondary domains with confidence scores above 0.3
|
||||
|
||||
### Rationale
|
||||
- LLM-based approach provides better semantic understanding of queries
|
||||
- Multi-domain classification with confidence scores handles complex queries better
|
||||
- Self-explaining classifications with reasoning aids debugging and transparency
|
||||
- The approach automatically adapts to new topics without code changes
|
||||
- Retaining keyword-based fallback ensures system resilience
|
||||
|
||||
### Alternatives Considered
|
||||
1. Expanding the keyword lists:
|
||||
- Would still lack semantic understanding
|
||||
- Would increase the maintenance burden
|
||||
- False positives would still occur
|
||||
|
||||
2. Using embedding similarity to predefined domain descriptions:
|
||||
- Potentially more computationally expensive
|
||||
- Less explainable than the LLM's reasoning
|
||||
- Would require managing embedding models
|
||||
|
||||
3. Creating a custom classifier:
|
||||
- Would require labeled training data
|
||||
- More development effort
|
||||
- Less flexible than the LLM approach
|
||||
|
||||
### Impact
|
||||
- More accurate query classification, especially for ambiguous or multi-domain queries
|
||||
- Reduction in maintenance overhead for keyword lists
|
||||
- Better search engine selection based on query domains
|
||||
- Improved report generation due to more accurate query understanding
|
||||
- Enhanced debugging capabilities with classification reasoning
|
|
@ -0,0 +1,312 @@
|
|||
# FastAPI Implementation Plan for Sim-Search (COMPLETED)
|
||||
|
||||
## Overview
|
||||
|
||||
This document outlines the plan for implementing a FastAPI backend for the sim-search project, replacing the current Gradio interface while maintaining all existing functionality. The API will serve as the backend for a new React frontend, providing a more flexible and powerful user experience.
|
||||
|
||||
✅ **Implementation Status: COMPLETED on March 20, 2025**
|
||||
|
||||
## Architecture
|
||||
|
||||
### Core Components
|
||||
|
||||
1. **API Layer** ✅
|
||||
- FastAPI application with RESTful endpoints
|
||||
- OpenAPI documentation
|
||||
- Authentication middleware
|
||||
- CORS configuration
|
||||
|
||||
2. **Service Layer** ✅
|
||||
- Bridge between API and existing sim-search functionality
|
||||
- Handles async/sync coordination
|
||||
- Implements caching and optimization strategies
|
||||
|
||||
3. **Data Layer** ✅
|
||||
- SQLAlchemy ORM models
|
||||
- Database session management
|
||||
- Migration scripts using Alembic
|
||||
|
||||
4. **Authentication System** ✅
|
||||
- JWT-based authentication
|
||||
- User management
|
||||
- Role-based access control
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
sim-search-api/
|
||||
├── app/
|
||||
│ ├── api/
|
||||
│ │ ├── routes/
|
||||
│ │ │ ├── __init__.py
|
||||
│ │ │ ├── query.py # Query processing endpoints
|
||||
│ │ │ ├── search.py # Search execution endpoints
|
||||
│ │ │ ├── report.py # Report generation endpoints
|
||||
│ │ │ └── auth.py # Authentication endpoints
|
||||
│ │ ├── __init__.py
|
||||
│ │ └── dependencies.py # API dependencies (auth, rate limiting)
|
||||
│ ├── core/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── config.py # API configuration
|
||||
│ │ └── security.py # Security utilities
|
||||
│ ├── db/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── session.py # Database session
|
||||
│ │ └── models.py # Database models for reports, searches
|
||||
│ ├── schemas/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── token.py # Token schemas
|
||||
│ │ ├── user.py # User schemas
|
||||
│ │ ├── query.py # Query schemas
|
||||
│ │ ├── search.py # Search result schemas
|
||||
│ │ └── report.py # Report schemas
|
||||
│ ├── services/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── query_service.py # Query processing service
|
||||
│ │ ├── search_service.py # Search execution service
|
||||
│ │ └── report_service.py # Report generation service
|
||||
│ └── main.py # FastAPI application
|
||||
├── alembic/ # Database migrations
|
||||
│ ├── versions/
|
||||
│ │ └── 001_initial_migration.py # Initial migration
|
||||
│ ├── env.py # Alembic environment
|
||||
│ └── script.py.mako # Alembic script template
|
||||
├── .env.example # Environment variables template
|
||||
├── alembic.ini # Alembic configuration
|
||||
├── requirements.txt # API dependencies
|
||||
├── run.py # Script to run the API
|
||||
└── README.md # API documentation
|
||||
```
|
||||
|
||||
## API Endpoints
|
||||
|
||||
### Authentication Endpoints ✅
|
||||
- `POST /api/v1/auth/token`: Get an authentication token
|
||||
- `POST /api/v1/auth/register`: Register a new user
|
||||
|
||||
### Query Processing Endpoints ✅
|
||||
- `POST /api/v1/query/process`: Process and enhance a user query
|
||||
- `POST /api/v1/query/classify`: Classify a query by type and intent
|
||||
|
||||
### Search Execution Endpoints ✅
|
||||
- `POST /api/v1/search/execute`: Execute a search with optional parameters
|
||||
- `GET /api/v1/search/engines`: Get available search engines
|
||||
- `GET /api/v1/search/history`: Get user's search history
|
||||
- `GET /api/v1/search/{search_id}`: Get results for a specific search
|
||||
- `DELETE /api/v1/search/{search_id}`: Delete a search from history
|
||||
|
||||
### Report Generation Endpoints ✅
|
||||
- `POST /api/v1/report/generate`: Generate a report from search results
|
||||
- `GET /api/v1/report/list`: Get a list of user's reports
|
||||
- `GET /api/v1/report/{report_id}`: Get a specific report
|
||||
- `DELETE /api/v1/report/{report_id}`: Delete a report
|
||||
- `GET /api/v1/report/{report_id}/download`: Download a report in specified format
|
||||
- `GET /api/v1/report/{report_id}/progress`: Get the progress of a report generation
|
||||
|
||||
## Database Models
|
||||
|
||||
### User Model ✅
|
||||
```python
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
email = Column(String, unique=True, index=True, nullable=False)
|
||||
hashed_password = Column(String, nullable=False)
|
||||
full_name = Column(String, nullable=True)
|
||||
is_active = Column(Boolean, default=True)
|
||||
is_superuser = Column(Boolean, default=False)
|
||||
```
|
||||
|
||||
### Search Model ✅
|
||||
```python
|
||||
class Search(Base):
|
||||
__tablename__ = "searches"
|
||||
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
user_id = Column(String, ForeignKey("users.id"))
|
||||
query = Column(String, nullable=False)
|
||||
enhanced_query = Column(String, nullable=True)
|
||||
query_type = Column(String, nullable=True)
|
||||
engines = Column(String, nullable=True) # Comma-separated list
|
||||
results_count = Column(Integer, default=0)
|
||||
results = Column(JSON, nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.datetime.utcnow)
|
||||
```
|
||||
|
||||
### Report Model ✅
|
||||
```python
|
||||
class Report(Base):
|
||||
__tablename__ = "reports"
|
||||
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
user_id = Column(String, ForeignKey("users.id"))
|
||||
search_id = Column(String, ForeignKey("searches.id"), nullable=True)
|
||||
title = Column(String, nullable=False)
|
||||
content = Column(Text, nullable=False)
|
||||
detail_level = Column(String, nullable=False, default="standard")
|
||||
query_type = Column(String, nullable=True)
|
||||
model_used = Column(String, nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow)
|
||||
```
|
||||
|
||||
## Service Layer Integration
|
||||
|
||||
### Integration Strategy ✅
|
||||
|
||||
The service layer acts as a bridge between the API endpoints and the existing sim-search functionality. Each service:
|
||||
|
||||
1. Imports the corresponding sim-search components
|
||||
2. Adapts the API request to the format expected by sim-search
|
||||
3. Calls the sim-search functionality
|
||||
4. Transforms the result to the API response format
|
||||
|
||||
Example from the implemented QueryService:
|
||||
|
||||
```python
|
||||
# Add sim-search to the python path
|
||||
sim_search_path = Path(settings.SIM_SEARCH_PATH)
|
||||
sys.path.append(str(sim_search_path))
|
||||
|
||||
# Import sim-search components
|
||||
from query.query_processor import QueryProcessor
|
||||
from query.llm_interface import LLMInterface
|
||||
|
||||
class QueryService:
|
||||
def __init__(self):
|
||||
self.query_processor = QueryProcessor()
|
||||
self.llm_interface = LLMInterface()
|
||||
|
||||
async def process_query(self, query: str) -> Dict[str, Any]:
|
||||
# Process the query using the sim-search query processor
|
||||
structured_query = await self.query_processor.process_query(query)
|
||||
|
||||
# Format the response
|
||||
return {
|
||||
"original_query": query,
|
||||
"structured_query": structured_query
|
||||
}
|
||||
```
|
||||
|
||||
## Authentication System
|
||||
|
||||
### JWT-Based Authentication ✅
|
||||
|
||||
The authentication system uses JSON Web Tokens (JWT) to manage user sessions:
|
||||
|
||||
1. User logs in with email and password
|
||||
2. Server validates credentials and generates a JWT token
|
||||
3. Token is included in subsequent requests in the Authorization header
|
||||
4. Server validates the token for each protected endpoint
|
||||
|
||||
Implementation using FastAPI's dependencies:
|
||||
|
||||
```python
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/auth/token")
|
||||
|
||||
def get_current_user(
|
||||
db: Session = Depends(get_db), token: str = Depends(oauth2_scheme)
|
||||
) -> models.User:
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
token_data = TokenPayload(**payload)
|
||||
except (JWTError, ValidationError):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Could not validate credentials",
|
||||
)
|
||||
user = db.query(models.User).filter(models.User.id == token_data.sub).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
if not user.is_active:
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
return user
|
||||
```
|
||||
|
||||
## Implementation Phases
|
||||
|
||||
### Phase 1: Core Setup ✅
|
||||
- Set up project structure
|
||||
- Implement database models and migrations
|
||||
- Create authentication system
|
||||
- Implement configuration management
|
||||
|
||||
### Phase 2: Service Layer ✅
|
||||
- Implement query service integration
|
||||
- Implement search service integration
|
||||
- Implement report service integration
|
||||
- Add error handling and logging
|
||||
|
||||
### Phase 3: API Endpoints ✅
|
||||
- Implement authentication endpoints
|
||||
- Implement query processing endpoints
|
||||
- Implement search execution endpoints
|
||||
- Implement report generation endpoints
|
||||
|
||||
### Phase 4: Testing and Documentation ✅
|
||||
- Generate API documentation
|
||||
- Create user documentation
|
||||
|
||||
### Phase 5: Deployment and Integration ⏳
|
||||
- Set up deployment configuration
|
||||
- Configure environment variables
|
||||
- Integrate with React frontend
|
||||
- Perform end-to-end testing
|
||||
|
||||
## Dependencies
|
||||
|
||||
```
|
||||
# FastAPI and ASGI server
|
||||
fastapi==0.103.1
|
||||
uvicorn==0.23.2
|
||||
|
||||
# Database
|
||||
sqlalchemy==2.0.21
|
||||
alembic==1.12.0
|
||||
|
||||
# Authentication
|
||||
python-jose==3.3.0
|
||||
passlib==1.7.4
|
||||
bcrypt==4.0.1
|
||||
python-multipart==0.0.6
|
||||
|
||||
# Validation and serialization
|
||||
pydantic==2.4.2
|
||||
email-validator==2.0.0
|
||||
|
||||
# Testing
|
||||
pytest==7.4.2
|
||||
httpx==0.25.0
|
||||
|
||||
# Utilities
|
||||
python-dotenv==1.0.0
|
||||
aiofiles==23.2.1
|
||||
jinja2==3.1.2
|
||||
|
||||
# Report generation
|
||||
markdown==3.4.4
|
||||
weasyprint==60.1 # Optional, for PDF generation
|
||||
```
|
||||
|
||||
## Next Steps
|
||||
|
||||
1. Test the FastAPI implementation to ensure it works correctly with the existing sim-search functionality
|
||||
2. Create a React frontend to consume the FastAPI backend
|
||||
3. Implement user management in the frontend
|
||||
4. Add search history and report management to the frontend
|
||||
5. Implement real-time progress tracking for report generation in the frontend
|
||||
6. Add visualization components for reports in the frontend
|
||||
7. Run comprehensive tests to ensure all functionality works with the new API
|
||||
8. Update any remaining documentation to reflect the new API
|
||||
9. Consider adding more API endpoints for additional functionality
|
||||
|
||||
## Conclusion
|
||||
|
||||
The FastAPI backend for the sim-search project has been successfully implemented according to this plan. The implementation provides a modern, maintainable, and scalable API that preserves all the functionality of the existing system while enabling new features and improvements through the planned React frontend.
|
||||
|
||||
The service layer pattern ensures a clean separation between the API and the existing sim-search functionality, making it easier to maintain and extend both components independently. This architecture also allows for future enhancements such as caching, background processing, and additional integrations without requiring major changes to the existing code.
|
||||
|
||||
The next phase of the project will focus on creating a React frontend to consume this API, providing a more flexible and powerful user experience.
|
|
@ -0,0 +1,397 @@
|
|||
# LLM-Based Query Classification Implementation Plan
|
||||
|
||||
## Overview
|
||||
|
||||
This document outlines a plan to replace the current keyword-based query classification system with an LLM-based approach. The current system uses predefined keyword lists to determine if a query is academic, code-related, or about current events. This approach is limited by the static nature of the keywords and doesn't capture the semantic meaning of queries. Switching to an LLM-based classification will provide more accurate and adaptable query typing.
|
||||
|
||||
## Current Limitations
|
||||
|
||||
1. **Keyword Dependency**:
|
||||
- The system relies on static lists of keywords that need constant updating
|
||||
- Many relevant terms are likely to be missing, especially for emerging topics
|
||||
- Some words have different meanings in different contexts (e.g., "model" can refer to code or academic concepts)
|
||||
|
||||
2. **False Classifications**:
|
||||
- Queries about LLMs being incorrectly classified as code-related instead of academic
|
||||
- General queries potentially being misclassified if they happen to contain certain keywords
|
||||
- No way to handle queries that span multiple categories
|
||||
|
||||
3. **Maintenance Burden**:
|
||||
- Need to regularly update keyword lists for each category
|
||||
- Complex if/then logic to determine query types
|
||||
- Hard to adapt to new research domains or technologies
|
||||
|
||||
## Proposed Solution
|
||||
|
||||
Replace the keyword-based classification with an LLM-based classification that:
|
||||
1. Uses semantic understanding to determine query intent and domain
|
||||
2. Can classify queries into multiple categories with confidence scores
|
||||
3. Provides reasoning for the classification
|
||||
4. Can adapt to new topics without code changes
|
||||
|
||||
## Technical Implementation
|
||||
|
||||
### 1. Extend LLM Interface with Domain Classification
|
||||
|
||||
Add a new method to the `LLMInterface` class in `query/llm_interface.py`:
|
||||
|
||||
```python
|
||||
async def classify_query_domain(self, query: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Classify a query's domain type (academic, code, current_events, general).
|
||||
|
||||
Args:
|
||||
query: The query to classify
|
||||
|
||||
Returns:
|
||||
Dictionary with query domain type and confidence scores
|
||||
"""
|
||||
# Get the model assigned to this function
|
||||
model_name = self.config.get_module_model('query_processing', 'classify_query_domain')
|
||||
|
||||
# Create a new interface with the assigned model if different from current
|
||||
if model_name != self.model_name:
|
||||
interface = LLMInterface(model_name)
|
||||
return await interface._classify_query_domain_impl(query)
|
||||
|
||||
return await self._classify_query_domain_impl(query)
|
||||
|
||||
async def _classify_query_domain_impl(self, query: str) -> Dict[str, Any]:
|
||||
"""Implementation of query domain classification."""
|
||||
messages = [
|
||||
{"role": "system", "content": """You are an expert query classifier.
|
||||
Analyze the given query and classify it into the following domain types:
|
||||
- academic: Related to scholarly research, scientific studies, academic papers, formal theories, university-level research topics, or scholarly fields of study
|
||||
- code: Related to programming, software development, technical implementation, coding languages, frameworks, or technology implementation questions
|
||||
- current_events: Related to recent news, ongoing developments, time-sensitive information, current politics, breaking stories, or real-time events
|
||||
- general: General information seeking that doesn't fit the above categories
|
||||
|
||||
You may assign multiple types if the query spans several domains.
|
||||
|
||||
Respond with a JSON object containing:
|
||||
{
|
||||
"primary_type": "the most appropriate type",
|
||||
"confidence": 0.X,
|
||||
"secondary_types": [{"type": "another_applicable_type", "confidence": 0.X}, ...],
|
||||
"reasoning": "brief explanation of your classification"
|
||||
}
|
||||
"""},
|
||||
{"role": "user", "content": query}
|
||||
]
|
||||
|
||||
# Generate classification
|
||||
response = await self.generate_completion(messages)
|
||||
|
||||
# Parse JSON response
|
||||
try:
|
||||
classification = json.loads(response)
|
||||
return classification
|
||||
except json.JSONDecodeError:
|
||||
# Fallback to default classification if parsing fails
|
||||
print(f"Error parsing domain classification response: {response}")
|
||||
return {
|
||||
"primary_type": "general",
|
||||
"confidence": 0.5,
|
||||
"secondary_types": [],
|
||||
"reasoning": "Failed to parse classification response"
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Update QueryProcessor Class
|
||||
|
||||
Modify the `QueryProcessor` class in `query/query_processor.py` to use the new LLM-based classification:
|
||||
|
||||
```python
|
||||
async def process_query(self, query: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Process a user query.
|
||||
|
||||
Args:
|
||||
query: The raw user query
|
||||
|
||||
Returns:
|
||||
Dictionary containing the processed query information
|
||||
"""
|
||||
logger.info(f"Processing query: {query}")
|
||||
|
||||
# Enhance the query
|
||||
enhanced_query = await self.llm_interface.enhance_query(query)
|
||||
logger.info(f"Enhanced query: {enhanced_query}")
|
||||
|
||||
# Classify the query type (factual, exploratory, comparative)
|
||||
query_type_classification = await self.llm_interface.classify_query(query)
|
||||
logger.info(f"Query type classification: {query_type_classification}")
|
||||
|
||||
# Classify the query domain (academic, code, current_events, general)
|
||||
domain_classification = await self.llm_interface.classify_query_domain(query)
|
||||
logger.info(f"Query domain classification: {domain_classification}")
|
||||
|
||||
# Extract entities from the classification
|
||||
entities = query_type_classification.get('entities', [])
|
||||
|
||||
# Structure the query using the new classification approach
|
||||
structured_query = self._structure_query_with_llm(
|
||||
query,
|
||||
enhanced_query,
|
||||
query_type_classification,
|
||||
domain_classification
|
||||
)
|
||||
|
||||
# Decompose the query into sub-questions (if complex enough)
|
||||
structured_query = await self.query_decomposer.decompose_query(query, structured_query)
|
||||
|
||||
# Log the number of sub-questions if any
|
||||
if 'sub_questions' in structured_query and structured_query['sub_questions']:
|
||||
logger.info(f"Decomposed into {len(structured_query['sub_questions'])} sub-questions")
|
||||
else:
|
||||
logger.info("Query was not decomposed into sub-questions")
|
||||
|
||||
return structured_query
|
||||
|
||||
def _structure_query_with_llm(self, original_query: str, enhanced_query: str,
|
||||
type_classification: Dict[str, Any],
|
||||
domain_classification: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Structure a query using LLM classification results.
|
||||
|
||||
Args:
|
||||
original_query: The original user query
|
||||
enhanced_query: The enhanced query
|
||||
type_classification: Classification of query type (factual, exploratory, comparative)
|
||||
domain_classification: Classification of query domain (academic, code, current_events)
|
||||
|
||||
Returns:
|
||||
Dictionary containing the structured query
|
||||
"""
|
||||
# Get primary domain and confidence
|
||||
primary_domain = domain_classification.get('primary_type', 'general')
|
||||
primary_confidence = domain_classification.get('confidence', 0.5)
|
||||
|
||||
# Get secondary domains
|
||||
secondary_domains = domain_classification.get('secondary_types', [])
|
||||
|
||||
# Determine domain flags
|
||||
is_academic = primary_domain == 'academic'
|
||||
is_code = primary_domain == 'code'
|
||||
is_current_events = primary_domain == 'current_events'
|
||||
|
||||
# Higher threshold for secondary domains to avoid false positives
|
||||
if primary_domain != 'academic' and any(d['type'] == 'academic' and d['confidence'] >= 0.3 for d in secondary_domains):
|
||||
is_academic = True
|
||||
|
||||
if primary_domain != 'code' and any(d['type'] == 'code' and d['confidence'] >= 0.3 for d in secondary_domains):
|
||||
is_code = True
|
||||
|
||||
if primary_domain != 'current_events' and any(d['type'] == 'current_events' and d['confidence'] >= 0.3 for d in secondary_domains):
|
||||
is_current_events = True
|
||||
|
||||
return {
|
||||
'original_query': original_query,
|
||||
'enhanced_query': enhanced_query,
|
||||
'type': type_classification.get('type', 'unknown'),
|
||||
'intent': type_classification.get('intent', 'research'),
|
||||
'entities': type_classification.get('entities', []),
|
||||
'domain': primary_domain,
|
||||
'domain_confidence': primary_confidence,
|
||||
'secondary_domains': secondary_domains,
|
||||
'classification_reasoning': domain_classification.get('reasoning', ''),
|
||||
'timestamp': None, # Will be filled in by the caller
|
||||
'is_current_events': is_current_events,
|
||||
'is_academic': is_academic,
|
||||
'is_code': is_code,
|
||||
'metadata': {
|
||||
'type_classification': type_classification,
|
||||
'domain_classification': domain_classification
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Remove Legacy Keyword-Based Classification Methods
|
||||
|
||||
Once the new LLM-based classification is working correctly, remove or deprecate the old keyword-based methods (note: keep them in place for as long as the fallback mechanism described in Section 8 still relies on them):
|
||||
- `_is_current_events_query`
|
||||
- `_is_academic_query`
|
||||
- `_is_code_query`
|
||||
|
||||
And the original `_structure_query` method.
|
||||
|
||||
### 4. Update Search Executor Integration
|
||||
|
||||
The `SearchExecutor` class already looks for the flags in the structured query:
|
||||
- `is_academic`
|
||||
- `is_code`
|
||||
- `is_current_events`
|
||||
|
||||
So no changes are needed to the `execute_search` method. The improved classification will simply provide more accurate flags.
|
||||
|
||||
### 5. Update Configuration
|
||||
|
||||
Add the new `classify_query_domain` function to the module model configuration to allow different models to be assigned to this function:
|
||||
|
||||
```yaml
|
||||
module_models:
|
||||
query_processing:
|
||||
enhance_query: llama-3.1-8b-instant # Fast model for query enhancement
|
||||
classify_query: llama-3.1-8b-instant # Fast model for query type classification
|
||||
classify_query_domain: llama-3.1-8b-instant # Fast model for domain classification
|
||||
generate_search_queries: llama-3.1-8b-instant # Fast model for search query generation
|
||||
```
|
||||
|
||||
### 6. Testing Plan
|
||||
|
||||
1. **Unit Tests**:
|
||||
- Create test cases for `classify_query_domain` with various query types
|
||||
- Verify correct classification of academic, code, and current events queries
|
||||
- Test edge cases and queries that span multiple domains
|
||||
|
||||
2. **Integration Tests**:
|
||||
- Test the full query processing pipeline with the new classification
|
||||
- Verify that the correct search engines are selected based on the classification
|
||||
- Compare results with the old keyword-based approach
|
||||
|
||||
3. **Regression Testing**:
|
||||
- Ensure that all existing functionality works with the new classification
|
||||
- Verify that no existing test cases fail
|
||||
|
||||
### 7. Logging and Monitoring
|
||||
|
||||
Add detailed logging to monitor the performance of the new classification:
|
||||
|
||||
```python
|
||||
logger.info(f"Query domain classification: primary={domain_classification.get('primary_type')} confidence={domain_classification.get('confidence')}")
|
||||
if domain_classification.get('secondary_types'):
|
||||
for sec_type in domain_classification.get('secondary_types'):
|
||||
logger.info(f"Secondary domain: {sec_type['type']} confidence={sec_type['confidence']}")
|
||||
logger.info(f"Classification reasoning: {domain_classification.get('reasoning', 'None provided')}")
|
||||
```
|
||||
|
||||
### 8. Fallback Mechanism
|
||||
|
||||
Implement a fallback to the keyword-based approach if the LLM classification fails:
|
||||
|
||||
```python
|
||||
try:
|
||||
domain_classification = await self.llm_interface.classify_query_domain(query)
|
||||
structured_query = self._structure_query_with_llm(query, enhanced_query, query_type_classification, domain_classification)
|
||||
except Exception as e:
|
||||
logger.error(f"LLM domain classification failed: {e}. Falling back to keyword-based classification.")
|
||||
# Fallback to keyword-based approach
|
||||
structured_query = self._structure_query(query, enhanced_query, query_type_classification)
|
||||
```
|
||||
|
||||
## Timeline and Resources
|
||||
|
||||
### Phase 1: Development (2-3 days)
|
||||
- Implement the new `classify_query_domain` method in `LLMInterface`
|
||||
- Create the new `_structure_query_with_llm` method in `QueryProcessor`
|
||||
- Update the `process_query` method to use the new approach
|
||||
- Add configuration for the new function
|
||||
|
||||
### Phase 2: Testing (1-2 days)
|
||||
- Create test cases for the new classification
|
||||
- Test with various query types
|
||||
- Compare with the old approach
|
||||
|
||||
### Phase 3: Deployment and Monitoring (1 day)
|
||||
- Deploy the new version
|
||||
- Monitor logs for classification issues
|
||||
- Adjust prompts and thresholds as needed
|
||||
|
||||
### Phase 4: Cleanup (1 day)
|
||||
- Remove the old keyword-based methods
|
||||
- Update documentation
|
||||
|
||||
## Expected Outcomes
|
||||
|
||||
1. **Improved Classification Accuracy**:
|
||||
- More accurate identification of academic, code, and current events queries
|
||||
- Better handling of queries that span multiple domains
|
||||
- Proper classification of queries about emerging topics (like LLMs)
|
||||
|
||||
2. **Reduced Maintenance**:
|
||||
- No need to update keyword lists
|
||||
- Adaptability to new domains without code changes
|
||||
|
||||
3. **Enhanced User Experience**:
|
||||
- More relevant search results
|
||||
- Better report generation due to proper query classification
|
||||
|
||||
4. **System Robustness**:
|
||||
- Graceful handling of edge cases
|
||||
- Better explanation of classification decisions
|
||||
- Proper confidence scoring for ambiguous queries
|
||||
|
||||
## Examples
|
||||
|
||||
To illustrate how the new approach would work, here are some examples:
|
||||
|
||||
### Example 1: Academic Query
|
||||
**Query**: "What are the technological, economic, and social implications of large language models in today's society?"
|
||||
|
||||
**Current Classification**: Might be misclassified as code-related due to "models"
|
||||
|
||||
**LLM Classification**:
|
||||
```json
|
||||
{
|
||||
"primary_type": "academic",
|
||||
"confidence": 0.9,
|
||||
"secondary_types": [
|
||||
{"type": "general", "confidence": 0.4}
|
||||
],
|
||||
"reasoning": "This query is asking about implications of LLMs across multiple domains (technological, economic, and social) which is a scholarly research topic that would be well-addressed by academic sources."
|
||||
}
|
||||
```
|
||||
|
||||
### Example 2: Code Query
|
||||
**Query**: "How do I implement a transformer model in PyTorch for text classification?"
|
||||
|
||||
**Current Classification**: Might be correctly classified as code-related due to "implement", "model", and "PyTorch"
|
||||
|
||||
**LLM Classification**:
|
||||
```json
|
||||
{
|
||||
"primary_type": "code",
|
||||
"confidence": 0.95,
|
||||
"secondary_types": [
|
||||
{"type": "academic", "confidence": 0.4}
|
||||
],
|
||||
"reasoning": "This is primarily a programming question about implementing a specific model in PyTorch, which is a coding framework. It has academic aspects since it relates to machine learning models, but the focus is on implementation."
|
||||
}
|
||||
```
|
||||
|
||||
### Example 3: Current Events Query
|
||||
**Query**: "What are the latest developments in the Ukraine conflict?"
|
||||
|
||||
**Current Classification**: Likely correct if "Ukraine" is in the current events entities list
|
||||
|
||||
**LLM Classification**:
|
||||
```json
|
||||
{
|
||||
"primary_type": "current_events",
|
||||
"confidence": 0.95,
|
||||
"secondary_types": [],
|
||||
"reasoning": "This query is asking about 'latest developments' in an ongoing conflict, which clearly indicates a focus on recent news and time-sensitive information."
|
||||
}
|
||||
```
|
||||
|
||||
### Example 4: Mixed Query
|
||||
**Query**: "How are LLMs being used to detect and prevent cyber attacks?"
|
||||
|
||||
**Current Classification**: Might have mixed signals from both academic and code keywords
|
||||
|
||||
**LLM Classification**:
|
||||
```json
|
||||
{
|
||||
"primary_type": "academic",
|
||||
"confidence": 0.7,
|
||||
"secondary_types": [
|
||||
{"type": "code", "confidence": 0.6},
|
||||
{"type": "current_events", "confidence": 0.3}
|
||||
],
|
||||
"reasoning": "This query relates to research on LLM applications in cybersecurity (academic), has technical implementation aspects (code), and could relate to recent developments in the field (current events). The primary focus appears to be on research and study of this application."
|
||||
}
|
||||
```
|
||||
|
||||
## Conclusion
|
||||
|
||||
Replacing the keyword-based classification with an LLM-based approach will significantly improve the accuracy and adaptability of the query classification system. This will lead to better search results and report generation, particularly for complex or multi-domain queries like those about large language models. The implementation can be completed in 5-7 days and will reduce ongoing maintenance work by eliminating the need to update keyword lists.
|
|
@ -0,0 +1,571 @@
|
|||
# React Frontend Implementation Plan for Sim-Search
|
||||
|
||||
## Overview
|
||||
|
||||
This document outlines the plan for implementing a React frontend for the sim-search project, replacing the current Gradio interface with a modern, responsive, and feature-rich user interface. The frontend will communicate with the new FastAPI backend to provide a seamless user experience.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Core Components
|
||||
|
||||
1. **Next.js Framework**
|
||||
- Server-side rendering for improved SEO
|
||||
- API routes for backend proxying if needed
|
||||
- Static site generation for performance
|
||||
|
||||
2. **Component Library**
|
||||
- Modular React components
|
||||
- Reusable UI elements
|
||||
- Styling with Tailwind CSS
|
||||
|
||||
3. **State Management**
|
||||
- React Query for server state
|
||||
- Context API for application state
|
||||
- Form state management
|
||||
|
||||
4. **Authentication**
|
||||
- JWT token management
|
||||
- Protected routes
|
||||
- User profile management
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
sim-search-ui/
|
||||
├── src/
|
||||
│ ├── components/
|
||||
│ │ ├── layout/
|
||||
│ │ │ ├── Header.jsx # Application header
|
||||
│ │ │ ├── Sidebar.jsx # Sidebar menu
|
||||
│ │ │ └── Layout.jsx # Main layout wrapper
|
||||
│ │ ├── search/
|
||||
│ │ │ ├── SearchForm.jsx # Search input form
|
||||
│ │ │ ├── SearchResults.jsx # Results display
|
||||
│ │ │ ├── ResultItem.jsx # Individual result
|
||||
│ │ │ └── EngineSelector.jsx # Search engine selector
|
||||
│ │ ├── report/
|
||||
│ │ │ ├── ReportGenerator.jsx # Report generation form
|
||||
│ │ │ ├── ReportViewer.jsx # Report display
|
||||
│ │ │ ├── ReportsList.jsx # Reports list/management
|
||||
│ │ │ └── ReportOptions.jsx # Report generation options
|
||||
│ │ ├── common/
|
||||
│ │ │ ├── Button.jsx # Reusable button component
|
||||
│ │ │ ├── Card.jsx # Card container component
|
||||
│ │ │ ├── Loading.jsx # Loading indicator
|
||||
│ │ │ └── Modal.jsx # Modal dialog
|
||||
│ │ └── auth/
|
||||
│ │ ├── LoginForm.jsx # User login form
|
||||
│ │ └── RegisterForm.jsx # User registration form
|
||||
│ ├── hooks/
|
||||
│ │ ├── useAuth.js # Authentication hook
|
||||
│ │ ├── useSearch.js # Search execution hook
|
||||
│ │ └── useReport.js # Report management hook
|
||||
│ ├── context/
|
||||
│ │ ├── AuthContext.jsx # Authentication context
|
||||
│ │ └── SearchContext.jsx # Search state context
|
||||
│ ├── services/
|
||||
│ │ ├── api.js # API client service
|
||||
│ │ ├── auth.js # Authentication service
|
||||
│ │ ├── search.js # Search service
|
||||
│ │ └── report.js # Report service
|
||||
│ ├── utils/
|
||||
│ │ ├── formatting.js # Text/data formatting utilities
|
||||
│ │ └── validation.js # Form validation utilities
|
||||
│ ├── styles/
|
||||
│ │ ├── globals.css # Global styles
|
||||
│ │ └── theme.js # Theme configuration
|
||||
│ └── pages/
|
||||
│ ├── _app.jsx # App component
|
||||
│ ├── index.jsx # Home page
|
||||
│ ├── search.jsx # Search page
|
||||
│ ├── reports/
|
||||
│ │ ├── index.jsx # Reports list page
|
||||
│ │ ├── [id].jsx # Individual report page
|
||||
│ │ └── new.jsx # New report page
|
||||
│ └── auth/
|
||||
│ ├── login.jsx # Login page
|
||||
│ └── register.jsx # Registration page
|
||||
├── public/
|
||||
│ ├── logo.svg # Application logo
|
||||
│ └── favicon.ico # Favicon
|
||||
├── tailwind.config.js # Tailwind configuration
|
||||
├── next.config.js # Next.js configuration
|
||||
└── package.json # Dependencies
|
||||
```
|
||||
|
||||
## Key Pages and Features
|
||||
|
||||
### Home Page
|
||||
- Overview of the system
|
||||
- Quick access to search and reports
|
||||
- Feature highlights and documentation
|
||||
|
||||
### Search Page
|
||||
- Comprehensive search form
|
||||
- Multiple search engine selection
|
||||
- Advanced search options
|
||||
- Results display with filtering and sorting
|
||||
- Options to generate reports from results
|
||||
|
||||
### Report Generation Page
|
||||
- Detail level selection
|
||||
- Query type selection
|
||||
- Model selection
|
||||
- Advanced options
|
||||
- Progress tracking
|
||||
|
||||
### Reports Management Page
|
||||
- List of generated reports
|
||||
- Filtering and sorting options
|
||||
- Download in different formats
|
||||
- Delete and manage reports
|
||||
|
||||
### Authentication Pages
|
||||
- Login page
|
||||
- Registration page
|
||||
- User profile management
|
||||
|
||||
## Component Design
|
||||
|
||||
### Search Components
|
||||
|
||||
#### SearchForm Component
|
||||
```jsx
|
||||
const SearchForm = ({ onSearchComplete }) => {
|
||||
const [query, setQuery] = useState('');
|
||||
const [selectedEngines, setSelectedEngines] = useState([]);
|
||||
const [numResults, setNumResults] = useState(10);
|
||||
const [useReranker, setUseReranker] = useState(true);
|
||||
const { engines, loading, error, loadEngines, search } = useSearch();
|
||||
|
||||
// Load available search engines on component mount
|
||||
useEffect(() => {
|
||||
loadEngines();
|
||||
}, []);
|
||||
|
||||
// Handle search submission
|
||||
const handleSubmit = async (e) => {
|
||||
e.preventDefault();
|
||||
|
||||
const searchParams = {
|
||||
query: query.trim(),
|
||||
search_engines: selectedEngines.length > 0 ? selectedEngines : undefined,
|
||||
num_results: numResults,
|
||||
use_reranker: useReranker,
|
||||
};
|
||||
|
||||
const results = await search(searchParams);
|
||||
|
||||
if (results && onSearchComplete) {
|
||||
onSearchComplete(results);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
// Form UI with input fields, engine selection, and options
|
||||
);
|
||||
};
|
||||
```
|
||||
|
||||
#### SearchResults Component
|
||||
```jsx
|
||||
const SearchResults = ({ results, query, onGenerateReport }) => {
|
||||
const [selectedResults, setSelectedResults] = useState([]);
|
||||
const [sortBy, setSortBy] = useState('relevance');
|
||||
|
||||
// Toggle a result's selection
|
||||
const toggleResultSelection = (resultIndex) => {
|
||||
setSelectedResults(prev => (
|
||||
prev.includes(resultIndex)
|
||||
? prev.filter(index => index !== resultIndex)
|
||||
: [...prev, resultIndex]
|
||||
));
|
||||
};
|
||||
|
||||
// Handle generate report button click
|
||||
const handleGenerateReport = () => {
|
||||
// Filter results to only include selected ones if any are selected
|
||||
const resultsToUse = selectedResults.length > 0
|
||||
? results.filter((result, index) => selectedResults.includes(index))
|
||||
: results;
|
||||
|
||||
if (onGenerateReport) {
|
||||
onGenerateReport(resultsToUse, query);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
// Results UI with sorting, filtering, and item selection
|
||||
);
|
||||
};
|
||||
```
|
||||
|
||||
### Report Components
|
||||
|
||||
#### ReportGenerator Component
|
||||
```jsx
|
||||
const ReportGenerator = ({ query, searchResults, searchId }) => {
|
||||
const [detailLevel, setDetailLevel] = useState('standard');
|
||||
const [queryType, setQueryType] = useState('auto-detect');
|
||||
const [customModel, setCustomModel] = useState('');
|
||||
const [initialResults, setInitialResults] = useState(10);
|
||||
const [finalResults, setFinalResults] = useState(7);
|
||||
const { loading, error, createReport } = useReport();
|
||||
|
||||
// Generate the report
|
||||
const handleGenerateReport = async () => {
|
||||
const reportParams = {
|
||||
query,
|
||||
search_id: searchId,
|
||||
search_results: !searchId ? searchResults : undefined,
|
||||
detail_level: detailLevel,
|
||||
query_type: queryType,
|
||||
custom_model: customModel || undefined,
|
||||
initial_results: initialResults,
|
||||
final_results: finalResults
|
||||
};
|
||||
|
||||
await createReport(reportParams);
|
||||
};
|
||||
|
||||
return (
|
||||
// Report generation form with options
|
||||
);
|
||||
};
|
||||
```
|
||||
|
||||
#### ReportViewer Component
|
||||
```jsx
|
||||
const ReportViewer = ({ report, onDownload }) => {
|
||||
const [selectedFormat, setSelectedFormat] = useState('markdown');
|
||||
const { download, loading } = useReport();
|
||||
|
||||
const handleDownload = async () => {
|
||||
if (onDownload) {
|
||||
onDownload(report.id, selectedFormat);
|
||||
} else {
|
||||
await download(report.id, selectedFormat);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
// Report content display with markdown rendering and download options
|
||||
);
|
||||
};
|
||||
```
|
||||
|
||||
## API Integration Services
|
||||
|
||||
### API Client Service
|
||||
```javascript
|
||||
import axios from 'axios';
|
||||
|
||||
// Create an axios instance with default config
|
||||
const api = axios.create({
|
||||
baseURL: process.env.NEXT_PUBLIC_API_URL || 'http://localhost:8000',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
});
|
||||
|
||||
// Add a request interceptor to include auth token in requests
|
||||
api.interceptors.request.use(
|
||||
(config) => {
|
||||
const token = localStorage.getItem('token');
|
||||
if (token) {
|
||||
config.headers.Authorization = `Bearer ${token}`;
|
||||
}
|
||||
return config;
|
||||
},
|
||||
(error) => Promise.reject(error)
|
||||
);
|
||||
|
||||
// Add a response interceptor to handle common errors
|
||||
api.interceptors.response.use(
|
||||
(response) => response,
|
||||
(error) => {
|
||||
// Handle 401 Unauthorized - redirect to login
|
||||
if (error.response && error.response.status === 401) {
|
||||
localStorage.removeItem('token');
|
||||
window.location.href = '/auth/login';
|
||||
}
|
||||
return Promise.reject(error);
|
||||
}
|
||||
);
|
||||
|
||||
export default api;
|
||||
```
|
||||
|
||||
### Search Service
|
||||
```javascript
|
||||
import api from './api';
|
||||
|
||||
export const executeSearch = async (searchParams) => {
|
||||
try {
|
||||
const response = await api.post('/api/v1/search/execute', searchParams);
|
||||
return { success: true, data: response.data };
|
||||
} catch (error) {
|
||||
return {
|
||||
success: false,
|
||||
error: error.response?.data?.detail || 'Failed to execute search'
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
export const getAvailableEngines = async () => {
|
||||
try {
|
||||
const response = await api.get('/api/v1/search/engines');
|
||||
return { success: true, data: response.data };
|
||||
} catch (error) {
|
||||
return {
|
||||
success: false,
|
||||
error: error.response?.data?.detail || 'Failed to get search engines'
|
||||
};
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
### Report Service
|
||||
```javascript
|
||||
import api from './api';
|
||||
|
||||
export const generateReport = async (reportParams) => {
|
||||
try {
|
||||
const response = await api.post('/api/v1/report/generate', reportParams);
|
||||
return { success: true, data: response.data };
|
||||
} catch (error) {
|
||||
return {
|
||||
success: false,
|
||||
error: error.response?.data?.detail || 'Failed to generate report'
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
export const getReportsList = async (skip = 0, limit = 100) => {
|
||||
try {
|
||||
const response = await api.get(`/api/v1/report/list?skip=${skip}&limit=${limit}`);
|
||||
return { success: true, data: response.data };
|
||||
} catch (error) {
|
||||
return {
|
||||
success: false,
|
||||
error: error.response?.data?.detail || 'Failed to get reports list'
|
||||
};
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
## Custom Hooks
|
||||
|
||||
### Authentication Hook
|
||||
```javascript
|
||||
import { useState, useEffect, useContext, createContext } from 'react';
|
||||
import { getCurrentUser, isAuthenticated } from '../services/auth';
|
||||
|
||||
// Create auth context
|
||||
const AuthContext = createContext(null);
|
||||
|
||||
// Auth provider component
|
||||
export const AuthProvider = ({ children }) => {
|
||||
const [user, setUser] = useState(null);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [error, setError] = useState(null);
|
||||
|
||||
useEffect(() => {
|
||||
// Check if user is authenticated and fetch user data
|
||||
const fetchUser = async () => {
|
||||
if (isAuthenticated()) {
|
||||
try {
|
||||
setLoading(true);
|
||||
const result = await getCurrentUser();
|
||||
if (result.success) {
|
||||
setUser(result.data);
|
||||
} else {
|
||||
setError(result.error);
|
||||
}
|
||||
} catch (err) {
|
||||
setError('Failed to fetch user data');
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
} else {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
fetchUser();
|
||||
}, []);
|
||||
|
||||
// Return provider with auth context
|
||||
return (
|
||||
<AuthContext.Provider value={{ user, loading, error, setUser }}>
|
||||
{children}
|
||||
</AuthContext.Provider>
|
||||
);
|
||||
};
|
||||
|
||||
// Custom hook to use auth context
|
||||
export const useAuth = () => {
|
||||
const context = useContext(AuthContext);
|
||||
if (context === null) {
|
||||
throw new Error('useAuth must be used within an AuthProvider');
|
||||
}
|
||||
return context;
|
||||
};
|
||||
```
|
||||
|
||||
### Search Hook

```javascript
import { useState } from 'react';
import { executeSearch, getAvailableEngines } from '../services/search';

export const useSearch = () => {
  const [results, setResults] = useState([]);
  const [engines, setEngines] = useState([]);
  const [loading, setLoading] = useState(false);
  const [error, setError] = useState(null);

  // Load available search engines
  const loadEngines = async () => {
    try {
      setLoading(true);
      const result = await getAvailableEngines();
      if (result.success) {
        setEngines(result.data);
      } else {
        setError(result.error);
      }
    } catch (err) {
      setError('Failed to load search engines');
    } finally {
      setLoading(false);
    }
  };

  // Execute a search
  const search = async (searchParams) => {
    try {
      setLoading(true);
      setError(null);
      const result = await executeSearch(searchParams);
      if (result.success) {
        setResults(result.data.results);
        return result.data;
      } else {
        setError(result.error);
        return null;
      }
    } catch (err) {
      setError('Failed to execute search');
      return null;
    } finally {
      setLoading(false);
    }
  };

  return {
    results,
    engines,
    loading,
    error,
    search,
    loadEngines,
  };
};
```
|
||||
## Implementation Phases

### Phase 1: Project Setup & Core Components (Week 1)
- Set up Next.js project
- Configure Tailwind CSS
- Implement common UI components
- Create layout components

### Phase 2: Authentication & API Integration (Week 1-2)
- Implement authentication components
- Create API service layer
- Implement custom hooks
- Set up protected routes

### Phase 3: Search Functionality (Week 2)
- Implement search form
- Create search results display
- Add filtering and sorting
- Implement search engine selection

### Phase 4: Report Generation & Management (Week 2-3)
- Implement report generation form
- Create report viewer with markdown rendering
- Add report management interface
- Implement download functionality

### Phase 5: Testing & Refinement (Week 3)
- Write component tests
- Perform cross-browser testing
- Add responsive design improvements
- Optimize performance

### Phase 6: Deployment & Documentation (Week 3-4)
- Set up deployment configuration
- Create user documentation
- Add inline help and tooltips
- Perform final testing
## Dependencies

```json
{
  "dependencies": {
    "next": "^13.5.4",
    "react": "^18.2.0",
    "react-dom": "^18.2.0",
    "axios": "^1.5.1",
    "react-markdown": "^9.0.0",
    "react-query": "^3.39.3",
    "tailwindcss": "^3.3.3",
    "postcss": "^8.4.31",
    "autoprefixer": "^10.4.16",
    "jose": "^4.14.6"
  },
  "devDependencies": {
    "eslint": "^8.51.0",
    "eslint-config-next": "^13.5.4",
    "typescript": "^5.2.2",
    "@types/react": "^18.2.28",
    "@types/node": "^20.8.6",
    "jest": "^29.7.0",
    "@testing-library/react": "^14.0.0",
    "@testing-library/jest-dom": "^6.1.4"
  }
}
```
## Accessibility Considerations

The React frontend will be built with accessibility in mind:

1. **Semantic HTML**: Use proper HTML elements for their intended purpose
2. **ARIA Attributes**: Add ARIA attributes where necessary
3. **Keyboard Navigation**: Ensure all interactive elements are keyboard accessible
4. **Focus Management**: Properly manage focus, especially in modals and dialogs
5. **Color Contrast**: Ensure sufficient color contrast for text and UI elements
6. **Screen Reader Support**: Test with screen readers to ensure compatibility
## Performance Optimization

To ensure optimal performance:

1. **Code Splitting**: Use Next.js code splitting to reduce initial bundle size
2. **Lazy Loading**: Implement lazy loading for components not needed immediately
3. **Memoization**: Use React.memo and useMemo to prevent unnecessary re-renders
4. **Image Optimization**: Use Next.js image optimization for faster loading
5. **API Response Caching**: Cache API responses with React Query
6. **Bundle Analysis**: Regularly analyze bundle size to identify improvements
## Conclusion

This implementation plan provides a structured approach to creating a modern React frontend for the sim-search project. By following this plan, we will create a user-friendly, accessible, and feature-rich interface that leverages the power of the new FastAPI backend.

The component-based architecture ensures reusability and maintainability, while the use of modern React patterns and hooks simplifies state management and side effects. The integration with the FastAPI backend is handled through a clean service layer, making it easy to adapt to changes in the API.

With this implementation, users will have a much improved experience compared to the current Gradio interface, with better search capabilities, more advanced report generation options, and a more intuitive interface for managing their research.
|
---
# Session Log
|
||||
|
||||
## Session: 2025-03-17
|
||||
## Session: 2025-03-20 - API Testing Implementation
|
||||
|
||||
### Overview
|
||||
Fixed bugs in the UI progress callback mechanism for report generation, consolidated redundant progress indicators, and resolved LLM provider configuration issues with OpenRouter models.
|
||||
Created a comprehensive testing framework for the sim-search API, including automated tests with pytest, a test runner script, and a manual testing script using curl commands.
|
||||
|
||||
### Key Activities
|
||||
1. Identified and fixed an AttributeError in the report generation progress callback:
|
||||
- Diagnosed the issue: 'Textbox' object has no attribute 'update'
|
||||
- Fixed by replacing `update(value=...)` method calls with direct value assignment (`component.value = ...`)
|
||||
- Committed changes with message "Fix AttributeError in report progress callback by using direct value assignment instead of update method"
|
||||
- Updated memory bank documentation with the fix details
|
||||
1. **Created Automated API Tests**:
|
||||
- Implemented `test_api.py` with pytest to test all API endpoints
|
||||
- Created tests for authentication, query processing, search execution, and report generation
|
||||
- Set up test fixtures for database initialization and user authentication
|
||||
- Implemented test database isolation to avoid affecting production data
|
||||
|
||||
2. Enhanced the progress indicator to ensure UI updates during async operations:
|
||||
- Identified that the progress indicator wasn't updating in real-time despite fixing the AttributeError
|
||||
- Implemented a solution using Gradio's built-in progress tracking mechanism
|
||||
- Added `progress(current_progress, desc=status_message)` to leverage Gradio's internal update mechanisms
|
||||
- Tested the solution to confirm progress indicators now update properly during report generation
|
||||
2. **Developed Test Runner Script**:
|
||||
- Created `run_tests.py` to simplify running the tests
|
||||
- Added command-line options for verbosity, coverage reporting, and test selection
|
||||
- Implemented clear output formatting for test results
|
||||
|
||||
3. Consolidated redundant progress indicators in the UI:
|
||||
- Identified three separate progress indicators in the UI (Progress Status textbox, progress slider, and built-in Gradio progress bar)
|
||||
- Removed the redundant Progress Status textbox and progress slider components
|
||||
- Simplified the UI to use only Gradio's built-in progress tracking mechanism
|
||||
- Updated the progress callback to work exclusively with the built-in progress mechanism
|
||||
- Tested the changes to ensure a cleaner, more consistent user experience
|
||||
3. **Created Manual Testing Script**:
|
||||
- Implemented `test_api_curl.sh` for manual testing with curl commands
|
||||
- Added tests for all API endpoints with proper authentication
|
||||
- Implemented colorized output for better readability
|
||||
- Added error handling and dependency checks between tests
|
||||
|
||||
4. **Added Test Documentation**:
|
||||
- Created a README.md file for the tests directory
|
||||
- Documented how to run the tests using different methods
|
||||
- Added troubleshooting information for common issues
|
||||
|
||||
### Insights
|
||||
- Gradio Textbox and Slider components use direct value assignment for updates rather than an update method
|
||||
- Asynchronous operations in Gradio require special handling to ensure UI elements update in real-time
|
||||
- Using Gradio's built-in progress tracking mechanism is more effective than manual UI updates for async tasks
|
||||
- When using LiteLLM with different model providers, it's essential to set the `custom_llm_provider` parameter correctly for each provider
|
||||
|
||||
4. Fixed LLM provider configuration for OpenRouter models:
|
||||
- Identified an issue with OpenRouter models not working correctly in the report synthesis module
|
||||
- Added the missing `custom_llm_provider = 'openrouter'` parameter to the LiteLLM completion parameters
|
||||
- Tested the fix to ensure OpenRouter models now work correctly for report generation
|
||||
- The progress callback mechanism is critical for providing user feedback during long-running report generation tasks
|
||||
- Proper error handling in UI callbacks is essential for a smooth user experience
|
||||
- Simplifying the UI by removing redundant progress indicators improves user experience and reduces confusion
|
||||
- Consolidating to a single progress indicator ensures consistent feedback and reduces code complexity
|
||||
|
||||
|
||||
## Session: 2025-02-27
|
||||
|
||||
### Overview
|
||||
Initial project setup and implementation of core functionality for semantic similarity search using Jina AI's APIs.
|
||||
|
||||
### Key Activities
|
||||
1. Created the core `JinaSimilarity` class in jina_similarity.py with the following features:
|
||||
- Token counting using tiktoken
|
||||
- Embedding generation using Jina AI's Embeddings API
|
||||
- Similarity computation using cosine similarity
|
||||
- Error handling for token limit violations
|
||||
|
||||
2. Implemented the markdown segmenter in markdown_segmenter.py:
|
||||
- Segmentation of markdown documents using Jina AI's Segmenter API
|
||||
- Command-line interface for easy usage
|
||||
|
||||
3. Developed a test script (test_similarity.py) with:
|
||||
- Command-line argument parsing
|
||||
- File reading functionality
|
||||
- Verbose output option for debugging
|
||||
- Error handling
|
||||
|
||||
4. Created sample files for testing:
|
||||
- sample_chunk.txt: Contains a paragraph about pangrams
|
||||
- sample_query.txt: Contains a question about pangrams
|
||||
|
||||
### Insights
|
||||
- Jina AI's embedding model (jina-embeddings-v3) provides high-quality embeddings for semantic search
|
||||
- The token limit of 8,192 tokens is sufficient for most use cases, but longer documents need segmentation
|
||||
- Normalizing embeddings simplifies similarity computation (dot product equals cosine similarity)
|
||||
- Separating segmentation from similarity computation provides better modularity
|
||||
- The FastAPI TestClient provides a convenient way to test API endpoints without starting a server
|
||||
- Using a separate test database ensures that tests don't affect production data
|
||||
- Pytest fixtures are useful for setting up and tearing down test environments
|
||||
- Manual testing with curl commands is useful for debugging and understanding the API
|
||||
|
||||
### Challenges
|
||||
- Ensuring proper error handling for API failures
|
||||
- Managing token limits for large documents
|
||||
- Balancing between chunking granularity and semantic coherence
|
||||
- Ensuring proper authentication for all API endpoints
|
||||
- Managing dependencies between tests (e.g., needing a search ID to generate a report)
|
||||
- Setting up a clean test environment for each test run
|
||||
- Handling asynchronous operations in tests
|
||||
|
||||
### Next Steps
|
||||
1. Add tiktoken to requirements.txt
|
||||
2. Implement caching for embeddings to reduce API calls
|
||||
3. Add batch processing capabilities for multiple chunks/queries
|
||||
4. Create comprehensive documentation and usage examples
|
||||
5. Develop integration tests for reliability testing
|
||||
1. Run the tests to verify that the API is working correctly
|
||||
2. Fix any issues found during testing
|
||||
3. Add more specific tests for edge cases and error handling
|
||||
4. Integrate the tests into a CI/CD pipeline
|
||||
5. Add performance tests for the API
|
||||
6. Consider adding integration tests with the frontend
|
||||
|
||||
## Session: 2025-02-27 (Update)
|
||||
## Session: 2025-03-20 - FastAPI Backend Implementation
|
||||
|
||||
### Overview
|
||||
Created memory bank for the project to maintain persistent knowledge about the codebase and development progress.
|
||||
Implemented a FastAPI backend for the sim-search project, replacing the current Gradio interface while maintaining all existing functionality. The API will serve as the backend for a new React frontend, providing a more flexible and powerful user experience.
|
||||
|
||||
### Key Activities
|
||||
1. Created the `.note/` directory to store memory bank files
|
||||
2. Created the following memory bank files:
|
||||
- project_overview.md: Purpose, goals, and high-level architecture
|
||||
- current_focus.md: Active work, recent changes, and next steps
|
||||
- development_standards.md: Coding conventions and patterns
|
||||
- decision_log.md: Key decisions with rationale
|
||||
- code_structure.md: Codebase organization with module descriptions
|
||||
- session_log.md: History of development sessions
|
||||
- interfaces.md: Component interfaces and API documentation
|
||||
1. **Created Directory Structure**:
|
||||
- Set up project structure following the implementation plan in `fastapi_implementation_plan.md`
|
||||
- Created directories for API routes, core functionality, database models, schemas, and services
|
||||
- Added proper `__init__.py` files to make all directories proper Python packages
|
||||
|
||||
2. **Implemented Core Components**:
|
||||
- Created FastAPI application with configuration and security
|
||||
- Implemented database models for users, searches, and reports
|
||||
- Set up database migrations with Alembic
|
||||
- Created API routes for authentication, query processing, search execution, and report generation
|
||||
- Implemented service layer to bridge between API and existing sim-search functionality
|
||||
- Added JWT-based authentication
|
||||
- Created comprehensive documentation for the API
|
||||
- Added environment variable configuration
|
||||
- Implemented OpenAPI documentation endpoints
|
||||
|
||||
3. **Created Service Layer**:
|
||||
- Implemented `QueryService` to bridge between API and existing query processing functionality
|
||||
- Created `SearchService` to handle search execution and result management
|
||||
- Implemented `ReportService` for report generation and management
|
||||
- Added proper error handling and logging throughout the service layer
|
||||
- Ensured asynchronous operation for all services
|
||||
|
||||
4. **Set Up Database**:
|
||||
- Created SQLAlchemy models for users, searches, and reports
|
||||
- Implemented database session management
|
||||
- Set up Alembic for database migrations
|
||||
- Created initial migration script to create all tables
|
||||
|
||||
### Insights
|
||||
- The project has a clear structure with well-defined components
|
||||
- The use of Jina AI's APIs provides powerful semantic search capabilities
|
||||
- The modular design allows for easy extension and maintenance
|
||||
- Some improvements are needed, such as adding tiktoken to requirements.txt
|
||||
|
||||
### Next Steps
|
||||
1. Update requirements.txt to include all dependencies (tiktoken)
|
||||
2. Implement caching mechanism for embeddings
|
||||
3. Add batch processing capabilities
|
||||
4. Create comprehensive documentation
|
||||
5. Develop integration tests
|
||||
|
||||
## Session: 2025-02-27 (Update 2)
|
||||
|
||||
### Overview
|
||||
Expanded the project scope to build a comprehensive intelligent research system with an 8-stage pipeline.
|
||||
|
||||
### Key Activities
|
||||
1. Defined the overall architecture for the intelligent research system:
|
||||
- 8-stage pipeline from query acceptance to report generation
|
||||
- Multiple search sources (Google, Serper, Jina Search, Google Scholar, arXiv)
|
||||
- Semantic processing using Jina AI's APIs
|
||||
|
||||
2. Updated the memory bank to reflect the broader vision:
|
||||
- Revised project_overview.md with the complete research system goals
|
||||
- Updated current_focus.md with next steps for each pipeline stage
|
||||
- Enhanced code_structure.md with planned project organization
|
||||
- Added new decisions to decision_log.md
|
||||
|
||||
### Insights
|
||||
- The modular pipeline architecture allows for incremental development
|
||||
- Jina AI's suite of APIs provides a consistent approach to semantic processing
|
||||
- Multiple search sources will provide more comprehensive research results
|
||||
- The current similarity components fit naturally into stages 6-7 of the pipeline
|
||||
|
||||
### Next Steps
|
||||
1. Begin implementing the query processing module (stage 1)
|
||||
2. Design the data structures for passing information between pipeline stages
|
||||
3. Create a project roadmap with milestones for each stage
|
||||
4. Prioritize development of core components for an end-to-end MVP
|
||||
|
||||
## Session: 2025-02-27 (Update 3)
|
||||
|
||||
### Overview
|
||||
Planned the implementation of the Query Processing Module with LiteLLM integration and Gradio UI.
|
||||
|
||||
### Key Activities
|
||||
1. Researched LiteLLM integration:
|
||||
- Explored LiteLLM documentation and usage patterns
|
||||
- Investigated integration with Gradio for UI development
|
||||
- Identified configuration requirements and best practices
|
||||
|
||||
2. Developed implementation plan:
|
||||
- Prioritized Query Processing Module with LiteLLM integration
|
||||
- Planned Gradio UI implementation for user interaction
|
||||
- Outlined configuration structure for API keys and settings
|
||||
- Established a sequence for implementing remaining modules
|
||||
|
||||
3. Updated memory bank:
|
||||
- Revised current_focus.md with new implementation plan
|
||||
- Added immediate and future steps for development
|
||||
|
||||
### Insights
|
||||
- LiteLLM provides a unified interface to multiple LLM providers, simplifying integration
|
||||
- Gradio offers an easy way to create interactive UIs for AI applications
|
||||
- The modular approach allows for incremental development and testing
|
||||
- Existing similarity components can be integrated into the pipeline at a later stage
|
||||
|
||||
### Next Steps
|
||||
1. Update requirements.txt with new dependencies (litellm, gradio, etc.)
|
||||
2. Create configuration structure for secure API key management
|
||||
3. Implement LiteLLM interface for query enhancement and classification
|
||||
4. Develop the query processor with structured output
|
||||
5. Build the Gradio UI for user interaction
|
||||
|
||||
## Session: 2025-02-27 (Update 4)
|
||||
|
||||
### Overview
|
||||
Implemented module-specific model configuration and created the Jina AI Reranker module.
|
||||
|
||||
### Key Activities
|
||||
1. Enhanced configuration structure:
|
||||
- Added support for module-specific model assignments
|
||||
- Configured different models for different tasks
|
||||
- Added detailed endpoint configurations for various providers
|
||||
|
||||
2. Updated LLMInterface:
|
||||
- Modified to support module-specific model configurations
|
||||
- Added support for different endpoint types (OpenAI, Azure, Ollama)
|
||||
- Implemented method delegation to use appropriate models for each task
|
||||
|
||||
3. Created Jina AI Reranker module:
|
||||
- Implemented document reranking using Jina AI's Reranker API
|
||||
- Added support for reranking documents with metadata
|
||||
- Configured to use the "jina-reranker-v2-base-multilingual" model
|
||||
|
||||
### Insights
|
||||
- Using different models for different tasks allows for optimizing performance and cost
|
||||
- Jina's reranker provides a specialized solution for document ranking
|
||||
- The modular approach allows for easy swapping of components and models
|
||||
|
||||
### Next Steps
|
||||
1. Implement the remaining query processing components
|
||||
2. Create the Gradio UI for user interaction
|
||||
3. Test the full system with end-to-end workflows
|
||||
|
||||
## Session: 2025-02-27 (Update 5)
|
||||
|
||||
### Overview
|
||||
Added support for OpenRouter and Groq as LLM providers and configured the system to use Groq for testing.
|
||||
|
||||
### Key Activities
|
||||
1. **Jina Reranker API Integration**:
|
||||
- Updated the `rerank` method in the JinaReranker class to match the expected API request format
|
||||
- Modified the request payload to send an array of plain string documents instead of objects
|
||||
- Enhanced response processing to handle both current and older API response formats
|
||||
- Added detailed logging for API requests and responses for better debugging
|
||||
|
||||
2. **Testing Improvements**:
|
||||
- Created a simplified test script (`test_simple_reranker.py`) to isolate and test the reranker functionality
|
||||
- Updated the main test script to focus on core functionality without complex dependencies
|
||||
- Implemented JSON result saving for better analysis of reranker output
|
||||
- Added proper error handling in tests to provide clear feedback on issues
|
||||
|
||||
3. **Code Quality Enhancements**:
|
||||
- Improved error handling throughout the reranker implementation
|
||||
- Added informative debug messages at key points in the execution flow
|
||||
- Ensured backward compatibility with previous API response formats
|
||||
- Documented the expected request and response structures
|
||||
|
||||
### Insights and Learnings
|
||||
- The Jina Reranker API expects documents as an array of plain strings, not objects with a "text" field
|
||||
- The reranker response format includes a "document" field in the results which may contain either the text directly or an object with a "text" field
|
||||
- Proper error handling and debug output are crucial for diagnosing issues with external API integrations
|
||||
- Isolating components for testing makes debugging much more efficient
|
||||
- The service layer pattern provides a clean separation between the API and the existing sim-search functionality
|
||||
- FastAPI's dependency injection system makes it easy to handle authentication and database sessions
|
||||
- Asynchronous operation is essential for handling long-running tasks like report generation
|
||||
- The layered architecture makes it easier to maintain and extend both components independently
|
||||
|
||||
### Challenges
|
||||
- Adapting to changes in the Jina Reranker API response format
|
||||
- Ensuring backward compatibility with older response formats
|
||||
- Debugging nested API response structures
|
||||
- Managing environment variables and configuration consistently across test scripts
|
||||
- Ensuring proper integration with the existing sim-search functionality
|
||||
- Handling asynchronous operations throughout the API
|
||||
- Managing database sessions and transactions
|
||||
- Implementing proper error handling and logging
|
||||
|
||||
### Next Steps
|
||||
1. **Expand Testing**: Develop more comprehensive test cases for the reranker with diverse document types
|
||||
2. **Integration**: Ensure the reranker is properly integrated with the result collector for end-to-end functionality
|
||||
3. **Documentation**: Update API documentation to reflect the latest changes to the reranker implementation
|
||||
4. **UI Integration**: Add reranker configuration options to the Gradio interface
|
||||
1. Test the FastAPI implementation to ensure it works correctly with the existing sim-search functionality
|
||||
2. Create a React frontend to consume the FastAPI backend
|
||||
3. Implement user management in the frontend
|
||||
4. Add search history and report management to the frontend
|
||||
5. Implement real-time progress tracking for report generation in the frontend
|
||||
6. Add visualization components for reports in the frontend
|
||||
7. Run comprehensive tests to ensure all functionality works with the new API
|
||||
8. Update any remaining documentation to reflect the new API
|
||||
9. Consider adding more API endpoints for additional functionality
|
||||
|
||||
## Session: 2025-02-27 - Report Generation Module Planning
|
||||
## Session: 2025-03-19 - Fixed Gradio UI Bug with List Object in Markdown Component
|
||||
|
||||
### Overview
|
||||
In this session, we focused on planning the Report Generation module, designing a comprehensive implementation approach, and making key decisions about document scraping, storage, and processing.
|
||||
Fixed a critical bug in the Gradio UI where a list object was being passed to a Markdown component, causing an AttributeError when the `expandtabs()` method was called on the list.
|
||||
|
||||
### Key Activities
|
||||
1. **Designed a Phased Implementation Plan**:
|
||||
- Created a four-phase implementation plan for the Report Generation module
|
||||
- Phase 1: Document Scraping and Storage
|
||||
- Phase 2: Document Prioritization and Chunking
|
||||
- Phase 3: Report Generation
|
||||
- Phase 4: Advanced Features
|
||||
- Documented the plan in the memory bank for future reference
|
||||
1. **Identified the Root Cause**:
|
||||
- The error occurred in the Gradio interface, specifically in the Markdown component's postprocess method
|
||||
- The error message was: `AttributeError: 'list' object has no attribute 'expandtabs'`
|
||||
- The issue was in the `_delete_selected_reports` and `refresh_reports_list` functions, which were returning three values (reports_data, choices, status_message), but the click handlers were only expecting two outputs (reports_checkbox_group, status_message)
|
||||
- This caused the list to be passed to the Markdown component, which expected a string
|
||||
|
||||
2. **Made Key Design Decisions**:
|
||||
- Decided to use Jina Reader for web scraping due to its clean content extraction capabilities
|
||||
- Chose SQLite for document storage to ensure persistence and efficient querying
|
||||
- Designed a database schema with Documents and Metadata tables
|
||||
- Planned a token budget management system to handle context window limitations
|
||||
- Decided on a map-reduce approach for processing large document collections
|
||||
2. **Implemented Fixes**:
|
||||
- Updated the click handlers for the delete button and refresh button to handle all three outputs
|
||||
- Added the reports_checkbox_group component twice in the outputs list to match the three return values
|
||||
- This ensured that the status_message (a string) was correctly passed to the Markdown component
|
||||
- Tested the fix by running the UI and verifying that the error no longer occurs
|
||||
|
||||
3. **Addressed Context Window Limitations**:
|
||||
- Evaluated Groq's Llama 3.3 70B Versatile model's 128K context window
|
||||
- Designed document prioritization strategies based on relevance scores
|
||||
- Planned chunking strategies for handling long documents
|
||||
- Considered alternative models with larger context windows for future implementation
|
||||
|
||||
4. **Updated Documentation**:
|
||||
- Added the implementation plan to the memory bank
|
||||
- Updated the decision log with rationale for key decisions
|
||||
- Revised the current focus to reflect the new implementation priorities
|
||||
- Added a new session log entry to document the planning process
|
||||
3. **Verified the Solution**:
|
||||
- Confirmed that the UI now works correctly without any errors
|
||||
- Tested various operations (deleting reports, refreshing the list) to ensure they work as expected
|
||||
- Verified that the status messages are displayed correctly in the UI
|
||||
|
||||
### Insights
|
||||
- A phased implementation approach allows for incremental development and testing
|
||||
- SQLite provides a good balance of simplicity and functionality for document storage
|
||||
- Jina Reader integrates well with our existing Jina components (embeddings, reranker)
|
||||
- The map-reduce pattern enables processing of unlimited document collections despite context window limitations
|
||||
- Document prioritization is crucial for ensuring the most relevant content is included in reports
|
||||
- Gradio's component handling requires careful matching between function return values and output components
|
||||
- When a function returns more values than there are output components, Gradio will try to pass the extra values to the last component
|
||||
- In this case, the list was being passed to the Markdown component, which expected a string
|
||||
- Adding the same component multiple times in the outputs list is a valid solution to handle multiple return values
|
||||
|
||||
### Challenges
|
||||
- Managing the 128K context window limitation with potentially large document collections
|
||||
- Balancing between document coverage and report quality
|
||||
- Ensuring efficient web scraping without overwhelming target websites
|
||||
- Designing a flexible architecture that can accommodate different models and approaches
|
||||
- Identifying the root cause of the error required careful analysis of the error message and the code
|
||||
- Understanding how Gradio handles function return values and output components
|
||||
- Ensuring that the fix doesn't introduce new issues
|
||||
|
||||
### Next Steps
|
||||
1. Begin implementing Phase 1 of the Report Generation module:
|
||||
- Set up the SQLite database with the designed schema
|
||||
- Implement the Jina Reader integration for web scraping
|
||||
- Create the document processing pipeline
|
||||
- Develop URL validation and normalization functionality
|
||||
- Add caching and deduplication for scraped content
|
||||
1. Consider adding more comprehensive error handling in the UI components
|
||||
2. Review other similar functions to ensure they don't have the same issue
|
||||
3. Add more detailed logging to help diagnose similar issues in the future
|
||||
4. Consider adding unit tests for the UI components to catch similar issues earlier
|
||||
|
||||
2. Plan for Phase 2 implementation:
|
||||
- Design the token budget management system
|
||||
- Develop document prioritization algorithms
|
||||
- Create chunking strategies for long documents
|
||||
|
||||
## Session: 2025-02-27 - Report Generation Module Implementation (Phase 1)
|
||||
## Session: 2025-03-19 - Model Provider Selection Fix in Report Generation
|
||||
|
||||
### Overview
|
||||
In this session, we implemented Phase 1 of the Report Generation module, focusing on document scraping and SQLite storage. We created the necessary components for scraping web pages, storing their content in a SQLite database, and retrieving documents for report generation.
|
||||
Fixed an issue with model provider selection in the report generation process, ensuring that the provider specified in the config.yaml file is correctly used throughout the report generation pipeline.
|
||||
|
||||
### Key Activities
|
||||
1. **Created Database Manager**:
|
||||
- Implemented a SQLite database manager with tables for documents and metadata
|
||||
- Added full CRUD operations for documents
|
||||
- Implemented transaction handling for data integrity
|
||||
- Created methods for document search and retrieval
|
||||
- Used aiosqlite for asynchronous database operations
|
||||
1. Identified the root cause of the model provider selection issue:
|
||||
- The model selected in the UI was correctly passed to the report generator
|
||||
- However, the provider information was not being properly respected
|
||||
- The code was trying to guess the provider based on the model name instead of using the provider from the config
|
||||
|
||||
2. **Implemented Document Scraper**:
|
||||
- Created a document scraper with Jina Reader API integration
|
||||
- Added fallback mechanism using BeautifulSoup for when Jina API fails
|
||||
- Implemented URL validation and normalization
|
||||
- Added content conversion to Markdown format
|
||||
- Implemented token counting using tiktoken
|
||||
- Created metadata extraction from HTML content
|
||||
- Added document deduplication using content hashing
|
||||
2. Implemented fixes to ensure proper provider selection:
|
||||
- Modified the `generate_completion` method in `ReportSynthesizer` to use the provider from the config file
|
||||
- Removed code that was trying to guess the provider based on the model name
|
||||
- Added proper formatting for different providers (Gemini, Groq, Anthropic, OpenAI)
|
||||
- Enhanced model parameter formatting to handle provider-specific requirements
|
||||
|
||||
3. **Developed Report Generator Base**:
|
||||
- Created the basic structure for the report generation process
|
||||
- Implemented methods to process search results by scraping URLs
|
||||
- Integrated with the database manager and document scraper
|
||||
- Set up the foundation for future phases
|
||||
|
||||
4. **Created Test Script**:
|
||||
- Developed a test script to verify functionality
|
||||
- Tested document scraping, storage, and retrieval
|
||||
- Verified search functionality within the database
|
||||
- Ensured proper error handling and fallback mechanisms
|
||||
3. Added detailed logging:
|
||||
- Added logging of the provider and model being used at key points in the process
|
||||
- Added logging of the final model parameter and provider being used
|
||||
- This helps with debugging any future issues with model selection
|
||||
|
||||
### Insights
|
||||
- The fallback mechanism for document scraping is crucial, as the Jina Reader API may not always be available or may fail for certain URLs
|
||||
- Asynchronous processing significantly improves performance when scraping multiple URLs
|
||||
- Content hashing is an effective way to prevent duplicate documents in the database
|
||||
- Storing metadata separately from document content provides flexibility for future enhancements
|
||||
- The SQLite database provides a good balance of simplicity and functionality for document storage
|
||||
- Different LLM providers have different requirements for model parameter formatting
|
||||
- For Gemini models, LiteLLM requires setting `custom_llm_provider` to 'vertex_ai'
|
||||
- Detailed logging is essential for tracking model and provider usage in complex systems
|
||||
|
||||
### Challenges
|
||||
- Handling different HTML structures across websites for metadata extraction
|
||||
- Managing asynchronous operations and error handling
|
||||
- Ensuring proper transaction handling for database operations
|
||||
- Balancing between clean content extraction and preserving important information
|
||||
|
||||
### Next Steps
|
||||
1. **Integration with Search Execution**:
|
||||
- Connect the report generation module to the search execution pipeline
|
||||
- Implement automatic processing of search results
|
||||
|
||||
2. **Begin Phase 2 Implementation**:
|
||||
- Develop document prioritization based on relevance scores
|
||||
- Implement chunking strategies for long documents
|
||||
- Create token budget management system
|
||||
|
||||
3. **Testing and Refinement**:
|
||||
- Create more comprehensive tests for edge cases
|
||||
- Refine error handling and logging
|
||||
- Optimize performance for large numbers of documents
|
||||
|
||||
## Session: 2025-02-27 (Update)
|
||||
|
||||
### Overview
|
||||
Implemented Phase 3 of the Report Generation module, focusing on report synthesis using LLMs with a map-reduce approach.
|
||||
|
||||
### Key Activities
|
||||
1. **Created Report Synthesis Module**:
|
||||
- Implemented the `ReportSynthesizer` class for generating reports using Groq's Llama 3.3 70B model
|
||||
- Created a map-reduce approach for processing document chunks:
|
||||
- Map phase: Extract key information from individual chunks
|
||||
- Reduce phase: Synthesize extracted information into a coherent report
|
||||
- Added support for different query types (factual, exploratory, comparative)
|
||||
- Implemented automatic query type detection based on query text
|
||||
- Added citation generation and reference management
|
||||
|
||||
2. **Updated Report Generator**:
|
||||
- Integrated the new report synthesis module with the existing report generator
|
||||
- Replaced the placeholder report generation with the new LLM-based synthesis
|
||||
- Added proper error handling and logging throughout the process
|
||||
|
||||
3. **Created Test Scripts**:
|
||||
- Developed a dedicated test script for the report synthesis functionality
|
||||
- Implemented tests with both sample data and real URLs
|
||||
- Added support for mock data to avoid API dependencies during testing
|
||||
- Verified end-to-end functionality from document scraping to report generation
|
||||
|
||||
4. **Fixed LLM Integration Issues**:
|
||||
- Corrected the model name format for Groq provider by prefixing it with 'groq/'
|
||||
- Improved error handling for API failures
|
||||
- Added proper logging for the map-reduce process
|
||||
|
||||
### Insights
|
||||
- The map-reduce approach is effective for processing large amounts of document data
|
||||
- Different query types benefit from specialized report templates
|
||||
- Groq's Llama 3.3 70B model produces high-quality reports with good coherence and factual accuracy
|
||||
- Proper citation management is essential for creating trustworthy reports
|
||||
- Automatic query type detection works well for common query patterns
|
||||
|
||||
### Challenges
|
||||
- Managing API errors and rate limits with external LLM providers
|
||||
- Ensuring consistent formatting across different report sections
|
||||
- Balancing between report comprehensiveness and token usage
|
||||
- Handling edge cases where document chunks contain irrelevant information
|
||||
|
||||
### Next Steps
|
||||
1. Implement support for alternative models with larger context windows
|
||||
2. Develop progressive report generation for very large research tasks
|
||||
3. Create visualization components for data mentioned in reports
|
||||
4. Add interactive elements to the generated reports
|
||||
5. Implement report versioning and comparison
|
||||
|
||||
## Session: 2025-02-27 (Update 2)
|
||||
|
||||
### Overview
|
||||
Successfully tested the end-to-end query-to-report pipeline with a specific query about the environmental and economic impact of electric vehicles, and fixed an issue with the Jina reranker integration.
|
||||
|
||||
### Key Activities
|
||||
1. **Fixed Jina Reranker Integration**:
|
||||
- Corrected the import statement in query_to_report.py to use the proper function name (get_jina_reranker)
|
||||
- Updated the reranker call to properly format the results for the JinaReranker
|
||||
- Implemented proper extraction of text from search results for reranking
|
||||
- Added mapping of reranked indices back to the original results
|
||||
|
||||
2. **Created EV Query Test Script**:
|
||||
- Developed a dedicated test script (test_ev_query.py) for testing the pipeline with a query about electric vehicles
|
||||
- Configured the script to use 7 results per search engine for a comprehensive report
|
||||
- Added proper error handling and result display
|
||||
|
||||
3. **Tested End-to-End Pipeline**:
|
||||
- Successfully executed the full query to report workflow
|
||||
- Verified that all components (query processor, search executor, reranker, report generator) work together seamlessly
|
||||
- Generated a comprehensive report on the environmental and economic impact of electric vehicles
|
||||
|
||||
4. **Identified Report Detail Configuration Options**:
|
||||
- Documented multiple ways to adjust the level of detail in generated reports
|
||||
- Identified parameters that can be modified to control report comprehensiveness
|
||||
- Created a plan for implementing customizable report detail levels
|
||||
|
||||
### Insights
|
||||
- The end-to-end pipeline successfully connects all major components of the system
|
||||
- The Jina reranker significantly improves the relevance of search results for report generation
|
||||
- The map-reduce approach effectively processes document chunks into a coherent report
|
||||
- Some document sources (like ScienceDirect and ResearchGate) may require special handling due to access restrictions
|
||||
|
||||
### Challenges
|
||||
- Handling API errors and access restrictions for certain document sources
|
||||
- Ensuring proper formatting of data between different components
|
||||
- Managing the processing of a large number of document chunks efficiently
|
||||
|
||||
### Next Steps
|
||||
1. **Implement Customizable Report Detail Levels**:
|
||||
- Develop a system to allow users to select different levels of detail for generated reports
|
||||
- Integrate the customizable detail levels into the report generator
|
||||
- Test the new feature with various query types
|
||||
|
||||
2. **Add Support for Alternative Models**:
|
||||
- Research and implement support for alternative models with larger context windows
|
||||
- Test the new models with the report generation pipeline
|
||||
|
||||
3. **Develop Progressive Report Generation**:
|
||||
- Design and implement a system for progressive report generation
|
||||
- Test the new feature with very large research tasks
|
||||
|
||||
4. **Create Visualization Components**:
|
||||
- Develop visualization components for data mentioned in reports
|
||||
- Integrate the visualization components into the report generator
|
||||
|
||||
5. **Add Interactive Elements**:
|
||||
- Develop interactive elements for the generated reports
|
||||
- Integrate the interactive elements into the report generator
|
||||
|
||||
## Session: 2025-02-28
|
||||
|
||||
### Overview
|
||||
Implemented customizable report detail levels for the Report Generation Module, allowing users to select different levels of detail for generated reports.
|
||||
|
||||
### Key Activities
|
||||
1. **Created Report Detail Levels Module**:
|
||||
- Implemented a new module `report_detail_levels.py` with an enum for detail levels (Brief, Standard, Detailed, Comprehensive)
|
||||
- Created a `ReportDetailLevelManager` class to manage detail level configurations
|
||||
- Defined specific parameters for each detail level (num_results, token_budget, chunk_size, overlap_size, model)
|
||||
- Added methods to validate and retrieve detail level configurations
|
||||
|
||||
2. **Updated Report Synthesis Module**:
|
||||
- Modified the `ReportSynthesizer` class to accept and use detail level parameters
|
||||
- Updated synthesis templates to adapt based on the selected detail level
|
||||
- Adjusted the map-reduce process to handle different levels of detail
|
||||
- Implemented model selection based on detail level requirements
|
||||
|
||||
3. **Enhanced Report Generator**:
|
||||
- Added methods to set and get detail levels in the `ReportGenerator` class
|
||||
- Updated the document preparation process to use detail level configurations
|
||||
- Modified the report generation workflow to incorporate detail level settings
|
||||
- Implemented validation for detail level parameters
|
||||
|
||||
4. **Updated Query to Report Script**:
|
||||
- Added command-line arguments for detail level selection
|
||||
- Implemented a `--list-detail-levels` option to display available options
|
||||
- Updated the main workflow to pass detail level parameters to the report generator
|
||||
- Added documentation for the new parameters
|
||||
|
||||
5. **Created Test Scripts**:
|
||||
- Updated `test_ev_query.py` to support detail level selection
|
||||
- Created a new `test_detail_levels.py` script to generate reports with all detail levels for comparison
|
||||
- Added metrics collection (timing, report size, word count) for comparison
|
||||
|
||||
### Insights
|
||||
- Different detail levels significantly affect report length, depth, and generation time
|
||||
- The brief level is useful for quick summaries, while comprehensive provides exhaustive information
|
||||
- Using different models for different detail levels offers a good balance between speed and quality
|
||||
- Configuring multiple parameters (num_results, token_budget, etc.) together creates a coherent detail level experience
|
||||
|
||||
### Challenges
|
||||
- Ensuring that the templates produce appropriate output for each detail level
|
||||
- Balancing between speed and quality for different detail levels
|
||||
- Managing token budgets effectively across different detail levels
|
||||
- Understanding the specific requirements for each provider in LiteLLM
|
||||
- Ensuring backward compatibility with existing code
|
||||
- Balancing between automatic provider detection and respecting explicit configuration
|
||||
|
||||
### Next Steps
|
||||
1. Conduct thorough testing of the detail level features with various query types
|
||||
2. Gather user feedback on the quality and usefulness of reports at different detail levels
|
||||
3. Refine the detail level configurations based on testing and feedback
|
||||
4. Implement progressive report generation for very large research tasks
|
||||
5. Develop visualization components for data mentioned in reports
|
||||
1. ✅ Test the fix with various models and providers to ensure it works in all scenarios
|
||||
2. ✅ Implement comprehensive unit tests for provider selection stability
|
||||
3. Update documentation to clarify how model and provider selection works
|
||||
|
||||
## Session: 2025-02-28 - Enhanced Report Detail Levels
|
||||
### Testing Results
|
||||
Created and executed a comprehensive test script (`report_synthesis_test.py`) to verify the model provider selection fix:
|
||||
|
||||
1. **Groq Provider (llama-3.3-70b-versatile)**:
|
||||
- Successfully initialized with provider "groq"
|
||||
- Completion parameters correctly showed: `'model': 'groq/llama-3.3-70b-versatile'`
|
||||
- LiteLLM logs confirmed: `LiteLLM completion() model= llama-3.3-70b-versatile; provider = groq`
|
||||
|
||||
2. **Gemini Provider (gemini-2.0-flash)**:
|
||||
- Successfully initialized with provider "gemini"
|
||||
- Completion parameters correctly showed: `'model': 'gemini-2.0-flash'` with `'custom_llm_provider': 'vertex_ai'`
|
||||
- Confirmed our fix for Gemini models using the correct vertex_ai provider
|
||||
|
||||
3. **Anthropic Provider (claude-3-opus-20240229)**:
|
||||
- Successfully initialized with provider "anthropic"
|
||||
- Completion parameters correctly showed: `'model': 'claude-3-opus-20240229'` with `'custom_llm_provider': 'anthropic'`
|
||||
- Received a successful response from Claude
|
||||
|
||||
4. **OpenAI Provider (gpt-4-turbo)**:
|
||||
- Successfully initialized with provider "openai"
|
||||
- Completion parameters correctly showed: `'model': 'gpt-4-turbo'` with `'custom_llm_provider': 'openai'`
|
||||
- Received a successful response from GPT-4
|
||||
|
||||
The test confirmed that our fix is working as expected, with the system now correctly:
|
||||
1. Using the provider specified in the config.yaml file
|
||||
2. Formatting the model parameters appropriately for each provider
|
||||
3. Logging the final model parameter and provider for better debugging
|
||||
|
||||
## Session: 2025-03-19 - Provider Selection Stability Testing
|
||||
|
||||
### Overview
|
||||
In this session, we enhanced the report detail levels to focus more on analytical depth rather than just adding additional sections. We improved the document chunk processing to extract more meaningful information from each chunk for detailed and comprehensive reports.
|
||||
Implemented comprehensive tests to ensure provider selection remains stable across multiple initializations, model switches, and direct configuration changes.
|
||||
|
||||
### Key Activities
|
||||
1. **Enhanced Template Modifiers for Detailed and Comprehensive Reports**:
|
||||
- Rewrote the template modifiers to focus on analytical depth, evidence density, and perspective diversity
|
||||
- Added explicit instructions to prioritize depth over breadth
|
||||
- Emphasized multi-layered analysis, causal relationships, and interconnections
|
||||
- Added instructions for exploring second and third-order effects
|
||||
1. Designed and implemented a test suite for provider selection stability:
|
||||
- Created `test_provider_selection_stability` function in `report_synthesis_test.py`
|
||||
- Implemented three main test scenarios to verify provider stability
|
||||
- Fixed issues with the test approach to properly use the global config singleton
|
||||
|
||||
2. **Improved Document Chunk Processing**:
|
||||
- Created a new `_get_extraction_prompt` method that provides different extraction prompts based on detail level
|
||||
- For DETAILED reports: Added focus on underlying principles, causal relationships, and different perspectives
|
||||
- For COMPREHENSIVE reports: Added focus on multi-layered analysis, complex causal networks, and theoretical frameworks
|
||||
- Modified the `map_document_chunks` method to pass the detail level parameter
|
||||
2. Test 1: Stability across multiple initializations with the same model
|
||||
- Verified that multiple synthesizers created with the same model consistently use the same provider
|
||||
- Ensured that provider selection is deterministic and not affected by initialization order
|
||||
|
||||
3. **Enhanced MapReduce Approach**:
|
||||
- Updated the map phase to use detail-level-specific extraction prompts
|
||||
- Ensured the detail level parameter is passed throughout the process
|
||||
- Maintained the efficient processing of document chunks while improving the quality of extraction
|
||||
3. Test 2: Stability when switching between models
|
||||
- Tested switching between different models (llama, gemini, claude, gpt) multiple times
|
||||
- Verified that each model consistently selects the appropriate provider based on configuration
|
||||
- Confirmed that switching back and forth between models maintains correct provider selection
|
||||
|
||||
4. Test 3: Stability with direct configuration changes
|
||||
- Tested the system's response to direct changes in the configuration
|
||||
- Modified the global config singleton to change a model's provider
|
||||
- Verified that new synthesizer instances correctly reflect the updated provider
|
||||
- Implemented proper cleanup to restore the original config state after testing
|
||||
|
||||
### Insights
|
||||
- The MapReduce approach is well-suited for LLM-based report generation, allowing processing of more information than would fit in a single context window
|
||||
- Different extraction prompts for different detail levels significantly affect the quality and depth of the extracted information
|
||||
- Focusing on analytical depth rather than additional sections provides more value to the end user
|
||||
- The enhanced prompts guide the LLM to provide deeper analysis of causal relationships, underlying mechanisms, and interconnections
|
||||
- The `ReportSynthesizer` class correctly uses the global config singleton for provider selection
|
||||
- Provider selection remains stable across multiple initializations with the same model
|
||||
- Provider selection correctly adapts when switching between different models
|
||||
- Provider selection properly responds to direct changes in the configuration
|
||||
- Using a try/finally block for config modifications ensures proper cleanup after tests
|
||||
|
||||
### Challenges
|
||||
- Balancing between depth and breadth in detailed reports
|
||||
- Ensuring that the extraction prompts extract the most relevant information for each detail level
|
||||
- Managing the increased processing time for detailed and comprehensive reports with enhanced extraction
|
||||
- Initial approach using a custom `TestSynthesizer` class didn't work as expected
|
||||
- The custom class was not correctly inheriting the config instance
|
||||
- Switched to directly modifying the global config singleton for more accurate testing
|
||||
- Needed to ensure proper cleanup to avoid side effects on other tests
|
||||
|
||||
### Next Steps
|
||||
1. Conduct thorough testing of the enhanced detail level features with various query types
|
||||
2. Compare the analytical depth and quality of reports generated with the new prompts
|
||||
3. Gather user feedback on the improved reports at different detail levels
|
||||
4. Explore parallel processing for the map phase to reduce overall report generation time
|
||||
5. Further refine the detail level configurations based on testing and feedback
|
||||
1. Consider adding more comprehensive tests for edge cases (e.g., invalid providers)
|
||||
2. Add tests for provider fallback mechanisms when specified providers are unavailable
|
||||
3. Document the provider selection process in the codebase for future reference
|
||||
|
||||
## Session: 2025-02-28 - Gradio UI Enhancements and Future Planning

## Session: 2025-03-20 - Enhanced Provider Selection Stability Testing
|
||||
|
||||
### Overview
|
||||
In this session, we fixed issues in the Gradio UI for report generation and planned future enhancements to improve search quality and user experience.
|
||||
|
||||
Expanded the provider selection stability tests to include additional scenarios such as fallback mechanisms, edge cases with invalid providers, provider selection when using singleton vs. creating new instances, and stability after config reload.
|
||||
|
||||
### Key Activities
|
||||
1. **Fixed Gradio UI for Report Generation**:
|
||||
- Updated the `generate_report` method in the Gradio UI to properly process queries and generate structured queries
|
||||
- Integrated the `QueryProcessor` to create structured queries from user input
|
||||
- Fixed method calls and parameter passing to the `execute_search` method
|
||||
- Implemented functionality to process `<thinking>` tags in the generated report
|
||||
- Added support for custom model selection in the UI
|
||||
- Updated the interfaces documentation to include ReportGenerator and ReportDetailLevelManager interfaces
|
||||
|
||||
2. **Planned Future Enhancements**:
|
||||
- **Multiple Query Variation Generation**:
|
||||
- Designed an approach to generate several similar queries with different keywords for better search coverage
|
||||
- Planned modifications to the QueryProcessor and SearchExecutor to handle multiple queries
|
||||
- Estimated this as a moderate difficulty task (3-4 days of work)
|
||||
|
||||
- **Threshold-Based Reranking with Larger Document Sets**:
|
||||
- Developed a plan to process more initial documents and use reranking to select the most relevant ones
|
||||
- Designed new detail level configuration parameters for initial and final result counts
|
||||
- Estimated this as an easy to moderate difficulty task (2-3 days of work)
|
||||
|
||||
- **UI Progress Indicators**:
|
||||
- Identified the need for chunk processing progress indicators in the UI
|
||||
- Planned modifications to report_synthesis.py to add logging during document processing
|
||||
- Estimated this as a simple enhancement (15-30 minutes of work)
|
||||
1. Enhanced the existing provider selection stability tests with additional test cases:
|
||||
- Added Test 4: Provider selection when using singleton vs. creating new instances
|
||||
- Added Test 5: Edge case with invalid provider
|
||||
- Added Test 6: Provider fallback mechanism
|
||||
- Added a new test function: `test_provider_selection_after_config_reload`
|
||||
|
||||
2. Test 4: Provider selection when using singleton vs. creating new instances
|
||||
- Verified that the singleton instance and a new instance with the same model use the same provider
|
||||
- Confirmed that the `get_report_synthesizer` function correctly handles model changes
|
||||
- Ensured consistent provider selection regardless of how the synthesizer is instantiated
|
||||
|
||||
3. Test 5: Edge case with invalid provider
|
||||
- Tested how the system handles models with invalid providers
|
||||
- Verified that the invalid provider is preserved in the configuration
|
||||
- Confirmed that the system doesn't crash when encountering an invalid provider
|
||||
- Validated that error logging is appropriate for debugging
|
||||
|
||||
4. Test 6: Provider fallback mechanism
|
||||
- Tested models with no explicit provider specified
|
||||
- Verified that the system correctly infers a provider based on the model name
|
||||
- Confirmed that the default fallback to groq works as expected
|
||||
|
||||
5. Test for provider selection after config reload
|
||||
- Simulated a config reload by creating a new Config instance
|
||||
- Verified that provider selection remains stable after config reload
|
||||
- Ensured proper cleanup of global state after testing
|
||||
|
||||
### Insights
|
||||
- The modular architecture of the system makes it easy to extend with new features
|
||||
- Providing progress indicators during report generation would significantly improve user experience
|
||||
- Generating multiple query variations could substantially improve search coverage and result quality
|
||||
- Using a two-stage approach (fetch more, then filter) for document retrieval would likely improve report quality
|
||||
|
||||
- The provider selection mechanism is robust across different instantiation methods
|
||||
- The system preserves invalid providers in the configuration, which is important for error handling and debugging
|
||||
- The fallback mechanism works correctly for models with no explicit provider
|
||||
- Provider selection remains stable even after config reload
|
||||
- Proper cleanup of global state is essential for preventing test interference
|
||||
|
||||
### Challenges
|
||||
- Balancing between fetching enough documents for comprehensive coverage and maintaining performance
|
||||
- Ensuring proper deduplication when using multiple query variations
|
||||
- Managing the increased API usage that would result from processing more queries and documents
|
||||
|
||||
- Simulating config reload required careful manipulation of the global config singleton
|
||||
- Testing invalid providers required handling expected errors without crashing the tests
|
||||
- Ensuring proper cleanup of global state after each test to prevent side effects
|
||||
|
||||
### Next Steps
|
||||
1. Implement the chunk processing progress indicators as a quick win
|
||||
2. Begin work on the multiple query variation generation feature
|
||||
3. Test the current implementation with various query types to identify any remaining issues
|
||||
4. Update the documentation to reflect the new features and future plans
|
||||
|
||||
## Session: 2025-03-12 - Query Type Selection in Gradio UI
|
||||
|
||||
### Overview
|
||||
In this session, we enhanced the Gradio UI by adding a query type selection dropdown, allowing users to explicitly select the query type (factual, exploratory, comparative) instead of relying on automatic detection.
|
||||
|
||||
### Key Activities
|
||||
1. **Added Query Type Selection to Gradio UI**:
|
||||
- Added a dropdown menu for query type selection in the "Generate Report" tab
|
||||
- Included options for "auto-detect", "factual", "exploratory", and "comparative"
|
||||
- Added descriptive tooltips explaining each query type
|
||||
- Set "auto-detect" as the default option
|
||||
|
||||
2. **Updated Report Generation Logic**:
|
||||
- Modified the `generate_report` method in the `GradioInterface` class to handle the new query_type parameter
|
||||
- Updated the report button click handler to pass the query type to the generate_report method
|
||||
- Added logging to show when a user-selected query type is being used
|
||||
|
||||
3. **Enhanced Report Generator**:
|
||||
- Updated the `generate_report` method in the `ReportGenerator` class to accept a query_type parameter
|
||||
- Modified the report synthesizer calls to pass the query_type parameter
|
||||
- Added logging to track query type usage
|
||||
|
||||
4. **Added Documentation**:
|
||||
- Added a "Query Types" section to the Gradio UI explaining each query type
|
||||
- Included examples of when to use each query type
|
||||
- Updated code comments to explain the query type parameter
|
||||
|
||||
### Insights
|
||||
- Explicit query type selection gives users more control over the report generation process
|
||||
- Different query types benefit from specialized report templates and structures
|
||||
- The auto-detect option provides convenience while still allowing manual override
|
||||
- Clear documentation helps users understand when to use each query type
|
||||
|
||||
### Challenges
|
||||
- Ensuring backward compatibility with existing code
|
||||
- Maintaining the auto-detect functionality while adding manual selection
|
||||
- Passing the query type parameter through multiple layers of the application
|
||||
- Providing clear explanations of query types for users
|
||||
|
||||
### Next Steps
|
||||
1. Test the query type selection with various queries to ensure it works correctly
|
||||
2. Gather user feedback on the usefulness of manual query type selection
|
||||
3. Consider adding more specialized templates for specific query types
|
||||
4. Explore adding query type detection confidence scores to help users decide when to override
|
||||
5. Add examples of each query type to help users understand the differences
|
||||
|
||||
## Session: 2025-03-12 - Fixed Query Type Parameter Bug
|
||||
|
||||
### Overview
|
||||
Fixed a bug in the report generation process where the `query_type` parameter was not properly handled, causing an error when it was `None`.
|
||||
|
||||
### Key Activities
|
||||
1. **Fixed NoneType Error in Report Synthesis**:
|
||||
- Added a null check in the `_get_extraction_prompt` method in `report_synthesis.py`
|
||||
- Modified the condition that checks for comparative queries to handle the case where `query_type` is `None`
|
||||
- Ensured the method works correctly regardless of whether a query type is explicitly provided
|
||||
|
||||
2. **Root Cause Analysis**:
|
||||
- Identified that the error occurred when the `query_type` parameter was `None` and the code tried to call `.lower()` on it
|
||||
- Traced the issue through the call chain from the UI to the report generator to the report synthesizer
|
||||
- Confirmed that the fix addresses the specific error message: `'NoneType' object has no attribute 'lower'`
|
||||
|
||||
### Insights
|
||||
- Proper null checking is essential when working with optional parameters that are passed through multiple layers
|
||||
- The error occurred in the report synthesis module but was triggered by the UI's query type selection feature
|
||||
- The fix maintains backward compatibility while ensuring the new query type selection feature works correctly
|
||||
|
||||
### Next Steps
|
||||
1. Test the fix with various query types to ensure it works correctly
|
||||
2. Consider adding similar null checks in other parts of the code that handle the query_type parameter
|
||||
3. Add more comprehensive error handling throughout the report generation process
|
||||
4. Update the test suite to include tests for null query_type values
|
||||
|
||||
## Session: 2025-03-12 - Fixed Template Retrieval for Null Query Type
|
||||
|
||||
### Overview
|
||||
Fixed a second issue in the report generation process where the template retrieval was failing when the `query_type` parameter was `None`.
|
||||
|
||||
### Key Activities
|
||||
1. **Fixed Template Retrieval for Null Query Type**:
|
||||
- Updated the `_get_template_from_strings` method in `report_synthesis.py` to handle `None` query_type
|
||||
- Added a default value of "exploratory" when query_type is `None`
|
||||
- Modified the method signature to explicitly indicate that query_type_str can be `None`
|
||||
- Added logging to indicate when the default query type is being used
|
||||
|
||||
2. **Root Cause Analysis**:
|
||||
- Identified that the error occurred when trying to convert `None` to a `QueryType` enum value
|
||||
- The error message was: "No template found for None standard" and "None is not a valid QueryType"
|
||||
- The issue was in the template retrieval process which is used by both standard and progressive report synthesis
|
||||
|
||||
### Insights
|
||||
- When fixing one issue with optional parameters, it's important to check for similar issues in related code paths
|
||||
- Providing sensible defaults for optional parameters helps maintain robustness
|
||||
- Proper error handling and logging helps diagnose issues in complex systems with multiple layers
|
||||
|
||||
### Next Steps
|
||||
1. Test the fix with comprehensive reports to ensure it works correctly
|
||||
2. Consider adding similar default values for other optional parameters
|
||||
3. Review the codebase for other potential null reference issues
|
||||
4. Update documentation to clarify the behavior when optional parameters are not provided
|
||||
1. Document the provider selection process in the codebase for future reference
|
||||
2. Consider adding tests for more complex scenarios like provider failover
|
||||
3. Explore adding a provider validation step during initialization
|
||||
4. Add more detailed error messages for invalid provider configurations
|
||||
5. Consider implementing a provider capability check to ensure the selected provider can handle the requested model
|
||||
|
|
|
@ -0,0 +1,122 @@
|
|||
# LLM-Based Query Classification
|
||||
|
||||
## Overview
|
||||
|
||||
This document describes the implementation of LLM-based query domain classification in the sim-search project, replacing the previous keyword-based approach.
|
||||
|
||||
## Motivation
|
||||
|
||||
The previous keyword-based classification had several limitations:
|
||||
- Relied on static lists of keywords that needed constant updating
|
||||
- Could not capture the semantic meaning of queries
|
||||
- Generated false classifications for ambiguous or novel queries
|
||||
- Required significant maintenance to keep keyword lists updated
|
||||
|
||||
## Implementation
|
||||
|
||||
### New Components
|
||||
|
||||
1. **LLM Interface Extension**:
|
||||
- Added `classify_query_domain()` method to `LLMInterface` class
|
||||
- Added `_classify_query_domain_impl()` private implementation method
|
||||
- Configured to use the fast Llama-3.1-8b-instant model by default
|
||||
|
||||
2. **Query Processor Updates**:
|
||||
- Added `_structure_query_with_llm()` method that uses the LLM classification results
|
||||
- Updated `process_query()` to use both query type and domain classification
|
||||
- Retained keyword-based method as a fallback in case of LLM API failures
|
||||
|
||||
3. **Structured Query Enhancements**:
|
||||
- Added new fields to the structured query:
|
||||
- `domain`: Primary domain type (academic, code, current_events, general)
|
||||
- `domain_confidence`: Confidence score for the primary domain
|
||||
- `secondary_domains`: Array of secondary domains with confidence scores
|
||||
- `classification_reasoning`: Explanation of the classification
|
||||
|
||||
4. **Configuration Updates**:
|
||||
- Added `classify_query_domain` to the module-specific model assignments
|
||||
- Using the same Llama-3.1-8b-instant model for domain classification as for other query processing tasks
|
||||
|
||||
5. **Logging and Monitoring**:
|
||||
- Added detailed logging of domain classification results
|
||||
- Log secondary domains with confidence scores
|
||||
- Log the reasoning behind classifications
|
||||
|
||||
6. **Error Handling**:
|
||||
- Added fallback to keyword-based classification if LLM-based classification fails
|
||||
- Implemented robust JSON parsing with fallbacks to default values
|
||||
- Added explicit error messages for troubleshooting
|
||||
|
||||
### Classification Process
|
||||
|
||||
The query domain classification process works as follows:
|
||||
|
||||
1. The query is sent to the LLM with a prompt specifying the four domain types
|
||||
2. The LLM returns a JSON response containing:
|
||||
- Primary domain type with confidence score
|
||||
- Array of secondary domain types with confidence scores
|
||||
- Reasoning for the classification
|
||||
3. The response is parsed and integrated into the structured query
|
||||
4. The `is_academic`, `is_code`, and `is_current_events` flags are set based on:
|
||||
- Primary domain matching the type
|
||||
- Any secondary domain matching the type with confidence above 0.3
|
||||
5. The structured query is then used by downstream components like the search executor
|
||||
|
||||
## Benefits
|
||||
|
||||
The new approach offers several advantages:
|
||||
|
||||
1. **Semantic Understanding**: Captures the meaning and intent of queries rather than just keyword matching
|
||||
2. **Multi-Domain Recognition**: Recognizes when queries span multiple domains with confidence scores
|
||||
3. **Self-Explaining**: Provides reasoning for classifications, aiding debugging and transparency
|
||||
4. **Adaptability**: Automatically adapts to new topics and terminology without code changes
|
||||
5. **Confidence Scoring**: Indicates how confident the system is in its classification
|
||||
|
||||
## Testing and Validation
|
||||
|
||||
A comprehensive test script (`test_domain_classification.py`) was created to:
|
||||
1. Test the raw domain classification function with a variety of queries
|
||||
2. Test the query processor's integration with domain classification
|
||||
3. Compare the LLM-based approach with the previous keyword-based approach
|
||||
|
||||
## Examples
|
||||
|
||||
### Academic Query Example
|
||||
**Query**: "What are the technological, economic, and social implications of large language models in today's society?"
|
||||
|
||||
**LLM Classification**:
|
||||
```json
|
||||
{
|
||||
"primary_type": "academic",
|
||||
"confidence": 0.9,
|
||||
"secondary_types": [
|
||||
{"type": "general", "confidence": 0.4}
|
||||
],
|
||||
"reasoning": "This query is asking about implications of LLMs across multiple domains (technological, economic, and social) which is a scholarly research topic that would be well-addressed by academic sources."
|
||||
}
|
||||
```
|
||||
|
||||
### Code Query Example
|
||||
**Query**: "How do I implement a transformer model in PyTorch for text classification?"
|
||||
|
||||
**LLM Classification**:
|
||||
```json
|
||||
{
|
||||
"primary_type": "code",
|
||||
"confidence": 0.95,
|
||||
"secondary_types": [
|
||||
{"type": "academic", "confidence": 0.4}
|
||||
],
|
||||
"reasoning": "This is primarily a programming question about implementing a specific model in PyTorch, which is a coding framework. It has academic aspects since it relates to machine learning models, but the focus is on implementation."
|
||||
}
|
||||
```
|
||||
|
||||
## Future Improvements
|
||||
|
||||
Potential enhancements for the future:
|
||||
|
||||
1. **Caching**: Add caching for frequently asked or similar queries to reduce API calls
|
||||
2. **Few-Shot Learning**: Add examples in the prompt to improve classification accuracy
|
||||
3. **Expanded Domains**: Consider additional domain categories beyond the current four
|
||||
4. **UI Integration**: Expose classification reasoning in the UI for advanced users
|
||||
5. **Classification Feedback Loop**: Allow users to correct misclassifications to improve the system over time
|
|
@ -199,6 +199,66 @@ class LLMInterface:
|
|||
# Return error message in a user-friendly format
|
||||
return f"I encountered an error while processing your request: {str(e)}"
|
||||
|
||||
async def classify_query_domain(self, query: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Classify a query's domain type (academic, code, current_events, general).
|
||||
|
||||
Args:
|
||||
query: The query to classify
|
||||
|
||||
Returns:
|
||||
Dictionary with query domain type and confidence scores
|
||||
"""
|
||||
# Get the model assigned to this function
|
||||
model_name = self.config.get_module_model('query_processing', 'classify_query_domain')
|
||||
|
||||
# Create a new interface with the assigned model if different from current
|
||||
if model_name != self.model_name:
|
||||
interface = LLMInterface(model_name)
|
||||
return await interface._classify_query_domain_impl(query)
|
||||
|
||||
return await self._classify_query_domain_impl(query)
|
||||
|
||||
async def _classify_query_domain_impl(self, query: str) -> Dict[str, Any]:
|
||||
"""Implementation of query domain classification."""
|
||||
messages = [
|
||||
{"role": "system", "content": """You are an expert query classifier.
|
||||
Analyze the given query and classify it into the following domain types:
|
||||
- academic: Related to scholarly research, scientific studies, academic papers, formal theories, university-level research topics, or scholarly fields of study
|
||||
- code: Related to programming, software development, technical implementation, coding languages, frameworks, or technology implementation questions
|
||||
- current_events: Related to recent news, ongoing developments, time-sensitive information, current politics, breaking stories, or real-time events
|
||||
- general: General information seeking that doesn't fit the above categories
|
||||
|
||||
You may assign multiple types if the query spans several domains.
|
||||
|
||||
Respond with a JSON object containing:
|
||||
{
|
||||
"primary_type": "the most appropriate type",
|
||||
"confidence": 0.X,
|
||||
"secondary_types": [{"type": "another_applicable_type", "confidence": 0.X}, ...],
|
||||
"reasoning": "brief explanation of your classification"
|
||||
}
|
||||
"""},
|
||||
{"role": "user", "content": query}
|
||||
]
|
||||
|
||||
# Generate classification
|
||||
response = await self.generate_completion(messages)
|
||||
|
||||
# Parse JSON response
|
||||
try:
|
||||
classification = json.loads(response)
|
||||
return classification
|
||||
except json.JSONDecodeError:
|
||||
# Fallback to default classification if parsing fails
|
||||
print(f"Error parsing domain classification response: {response}")
|
||||
return {
|
||||
"primary_type": "general",
|
||||
"confidence": 0.5,
|
||||
"secondary_types": [],
|
||||
"reasoning": "Failed to parse classification response"
|
||||
}
|
||||
|
||||
async def classify_query(self, query: str) -> Dict[str, str]:
|
||||
"""
|
||||
Classify a query as factual, exploratory, or comparative.
|
||||
|
|
|
@ -45,15 +45,27 @@ class QueryProcessor:
|
|||
enhanced_query = await self.llm_interface.enhance_query(query)
|
||||
logger.info(f"Enhanced query: {enhanced_query}")
|
||||
|
||||
# Classify the query
|
||||
classification = await self.llm_interface.classify_query(query)
|
||||
logger.info(f"Query classification: {classification}")
|
||||
# Classify the query type (factual, exploratory, comparative)
|
||||
query_type_classification = await self.llm_interface.classify_query(query)
|
||||
logger.info(f"Query type classification: {query_type_classification}")
|
||||
|
||||
# Extract entities from the classification
|
||||
entities = classification.get('entities', [])
|
||||
# Classify the query domain (academic, code, current_events, general)
|
||||
domain_classification = await self.llm_interface.classify_query_domain(query)
|
||||
logger.info(f"Query domain classification: {domain_classification}")
|
||||
|
||||
# Structure the query for downstream modules
|
||||
structured_query = self._structure_query(query, enhanced_query, classification)
|
||||
# Log classification details for monitoring
|
||||
if domain_classification.get('secondary_types'):
|
||||
for sec_type in domain_classification.get('secondary_types'):
|
||||
logger.info(f"Secondary domain: {sec_type['type']} confidence={sec_type['confidence']}")
|
||||
logger.info(f"Classification reasoning: {domain_classification.get('reasoning', 'None provided')}")
|
||||
|
||||
try:
|
||||
# Structure the query using the new classification approach
|
||||
structured_query = self._structure_query_with_llm(query, enhanced_query, query_type_classification, domain_classification)
|
||||
except Exception as e:
|
||||
logger.error(f"LLM domain classification failed: {e}. Falling back to keyword-based classification.")
|
||||
# Fallback to keyword-based approach
|
||||
structured_query = self._structure_query(query, enhanced_query, query_type_classification)
|
||||
|
||||
# Decompose the query into sub-questions (if complex enough)
|
||||
structured_query = await self.query_decomposer.decompose_query(query, structured_query)
|
||||
|
@ -66,10 +78,68 @@ class QueryProcessor:
|
|||
|
||||
return structured_query
|
||||
|
||||
def _structure_query_with_llm(self, original_query: str, enhanced_query: str,
|
||||
type_classification: Dict[str, Any],
|
||||
domain_classification: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Structure a query using LLM classification results.
|
||||
|
||||
Args:
|
||||
original_query: The original user query
|
||||
enhanced_query: The enhanced query
|
||||
type_classification: Classification of query type (factual, exploratory, comparative)
|
||||
domain_classification: Classification of query domain (academic, code, current_events)
|
||||
|
||||
Returns:
|
||||
Dictionary containing the structured query
|
||||
"""
|
||||
# Get primary domain and confidence
|
||||
primary_domain = domain_classification.get('primary_type', 'general')
|
||||
primary_confidence = domain_classification.get('confidence', 0.5)
|
||||
|
||||
# Get secondary domains
|
||||
secondary_domains = domain_classification.get('secondary_types', [])
|
||||
|
||||
# Determine domain flags
|
||||
is_academic = primary_domain == 'academic' or any(d['type'] == 'academic' for d in secondary_domains)
|
||||
is_code = primary_domain == 'code' or any(d['type'] == 'code' for d in secondary_domains)
|
||||
is_current_events = primary_domain == 'current_events' or any(d['type'] == 'current_events' for d in secondary_domains)
|
||||
|
||||
# Higher threshold for secondary domains to avoid false positives
|
||||
if primary_domain != 'academic' and any(d['type'] == 'academic' and d['confidence'] >= 0.3 for d in secondary_domains):
|
||||
is_academic = True
|
||||
|
||||
if primary_domain != 'code' and any(d['type'] == 'code' and d['confidence'] >= 0.3 for d in secondary_domains):
|
||||
is_code = True
|
||||
|
||||
if primary_domain != 'current_events' and any(d['type'] == 'current_events' and d['confidence'] >= 0.3 for d in secondary_domains):
|
||||
is_current_events = True
|
||||
|
||||
return {
|
||||
'original_query': original_query,
|
||||
'enhanced_query': enhanced_query,
|
||||
'type': type_classification.get('type', 'unknown'),
|
||||
'intent': type_classification.get('intent', 'research'),
|
||||
'entities': type_classification.get('entities', []),
|
||||
'domain': primary_domain,
|
||||
'domain_confidence': primary_confidence,
|
||||
'secondary_domains': secondary_domains,
|
||||
'classification_reasoning': domain_classification.get('reasoning', ''),
|
||||
'timestamp': None, # Will be filled in by the caller
|
||||
'is_current_events': is_current_events,
|
||||
'is_academic': is_academic,
|
||||
'is_code': is_code,
|
||||
'metadata': {
|
||||
'type_classification': type_classification,
|
||||
'domain_classification': domain_classification
|
||||
}
|
||||
}
|
||||
|
||||
def _structure_query(self, original_query: str, enhanced_query: str,
|
||||
classification: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Structure a query for downstream modules.
|
||||
Structure a query for downstream modules using keyword-based classification.
|
||||
This is a fallback method when LLM classification fails.
|
||||
|
||||
Args:
|
||||
original_query: The original user query
|
||||
|
@ -79,7 +149,7 @@ class QueryProcessor:
|
|||
Returns:
|
||||
Dictionary containing the structured query
|
||||
"""
|
||||
# Detect query types
|
||||
# Detect query types using keyword-based methods
|
||||
is_current_events = self._is_current_events_query(original_query, classification)
|
||||
is_academic = self._is_academic_query(original_query, classification)
|
||||
is_code = self._is_code_query(original_query, classification)
|
||||
|
@ -95,7 +165,8 @@ class QueryProcessor:
|
|||
'is_academic': is_academic,
|
||||
'is_code': is_code,
|
||||
'metadata': {
|
||||
'classification': classification
|
||||
'classification': classification,
|
||||
'classification_method': 'keyword' # Indicate this used the keyword-based method
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Binary file not shown.
|
@ -463,7 +463,20 @@ def get_progressive_report_synthesizer(model_name: Optional[str] = None) -> Prog
|
|||
global progressive_report_synthesizer
|
||||
|
||||
if model_name and model_name != progressive_report_synthesizer.model_name:
|
||||
progressive_report_synthesizer = ProgressiveReportSynthesizer(model_name)
|
||||
logger.info(f"Creating new progressive report synthesizer with model: {model_name}")
|
||||
try:
|
||||
previous_model = progressive_report_synthesizer.model_name
|
||||
progressive_report_synthesizer = ProgressiveReportSynthesizer(model_name)
|
||||
logger.info(f"Successfully changed progressive synthesizer model from {previous_model} to {model_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating new progressive report synthesizer with model {model_name}: {str(e)}")
|
||||
# Fall back to the existing synthesizer
|
||||
logger.info(f"Falling back to existing progressive synthesizer with model {progressive_report_synthesizer.model_name}")
|
||||
else:
|
||||
if model_name:
|
||||
logger.info(f"Using existing progressive report synthesizer with model: {model_name} (already initialized)")
|
||||
else:
|
||||
logger.info(f"Using existing progressive report synthesizer with default model: {progressive_report_synthesizer.model_name}")
|
||||
|
||||
return progressive_report_synthesizer
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ from report.report_synthesis import ReportSynthesizer
|
|||
|
||||
async def test_model_provider_selection():
|
||||
"""Test that model provider selection works correctly."""
|
||||
logger.info("=== Testing basic model provider selection ===")
|
||||
# Initialize config
|
||||
config = Config()
|
||||
|
||||
|
@ -81,10 +82,249 @@ async def test_model_provider_selection():
|
|||
|
||||
logger.info(f"===== Test completed for {model_name} with provider {provider} =====\n")
|
||||
|
||||
async def test_provider_selection_stability():
    """Test that provider selection remains stable across various scenarios.

    Covers six sub-tests:
      1. Same provider across repeated initializations with one model.
      2. Correct provider when switching between models repeatedly.
      3. Provider follows direct edits to the global config.
      4. Singleton accessor and fresh instances agree on the provider.
      5. An invalid provider string is preserved (not silently replaced).
      6. A model with no provider falls back to the default ("groq").

    NOTE(review): mutates the global config's 'models' dict; the original
    state is restored in the finally block. Assumes module-level `logger`
    and `ReportSynthesizer` are in scope — TODO confirm against file header.
    """
    logger.info("\n=== Testing provider selection stability ===")

    # Test 1: Stability across multiple initializations with the same model
    logger.info("\nTest 1: Stability across multiple initializations with the same model")
    model_name = "llama-3.3-70b-versatile"
    provider = "groq"

    # Create multiple synthesizers with the same model
    synthesizers = []
    for i in range(3):
        logger.info(f"Creating synthesizer {i+1} with model {model_name}")
        synthesizer = ReportSynthesizer(model_name=model_name)
        synthesizers.append(synthesizer)
        logger.info(f"Synthesizer {i+1} provider: {synthesizer.model_config.get('provider')}")

    # Verify all synthesizers have the same provider
    providers = [s.model_config.get('provider') for s in synthesizers]
    logger.info(f"Providers across synthesizers: {providers}")
    assert all(p == provider for p in providers), "Provider not stable across multiple initializations"
    logger.info("✅ Provider stable across multiple initializations")

    # Test 2: Stability when switching between models
    logger.info("\nTest 2: Stability when switching between models")
    model_configs = [
        {"name": "llama-3.3-70b-versatile", "provider": "groq"},
        {"name": "gemini-2.0-flash", "provider": "gemini"},
        {"name": "claude-3-opus-20240229", "provider": "anthropic"},
        {"name": "gpt-4-turbo", "provider": "openai"},
    ]

    # Test switching between models multiple times
    for _ in range(2):  # Do two rounds of switching
        for model_config in model_configs:
            model_name = model_config["name"]
            expected_provider = model_config["provider"]

            logger.info(f"Switching to model {model_name} with expected provider {expected_provider}")
            synthesizer = ReportSynthesizer(model_name=model_name)
            actual_provider = synthesizer.model_config.get('provider')

            logger.info(f"Model: {model_name}, Expected provider: {expected_provider}, Actual provider: {actual_provider}")
            assert actual_provider == expected_provider, f"Provider mismatch for {model_name}: expected {expected_provider}, got {actual_provider}"

    logger.info("✅ Provider selection stable when switching between models")

    # Test 3: Stability with direct configuration changes
    logger.info("\nTest 3: Stability with direct configuration changes")
    test_model = "test-model-stability"

    # Get the global config instance
    from config.config import config as global_config

    # Save original config state so the finally block can restore it.
    original_models = global_config.config_data.get('models', {}).copy()

    try:
        # Ensure models dict exists
        if 'models' not in global_config.config_data:
            global_config.config_data['models'] = {}

        # Set up test model with groq provider
        global_config.config_data['models'][test_model] = {
            "provider": "groq",
            "model_name": test_model,
            "temperature": 0.5,
            "max_tokens": 2048,
            "top_p": 1.0
        }

        # Create first synthesizer with groq provider
        logger.info(f"Creating first synthesizer with {test_model} using groq provider")
        synthesizer1 = ReportSynthesizer(model_name=test_model)
        provider1 = synthesizer1.model_config.get('provider')
        logger.info(f"Initial provider for {test_model}: {provider1}")

        # Change the provider in the global config
        global_config.config_data['models'][test_model]["provider"] = "anthropic"

        # Create second synthesizer with the updated config
        logger.info(f"Creating second synthesizer with {test_model} using anthropic provider")
        synthesizer2 = ReportSynthesizer(model_name=test_model)
        provider2 = synthesizer2.model_config.get('provider')
        logger.info(f"Updated provider for {test_model}: {provider2}")

        # Verify the provider was updated
        assert provider1 == "groq", f"Initial provider should be groq, got {provider1}"
        assert provider2 == "anthropic", f"Updated provider should be anthropic, got {provider2}"
        logger.info("✅ Provider selection responds correctly to configuration changes")

        # Test 4: Provider selection when using singleton vs. creating new instances
        logger.info("\nTest 4: Provider selection when using singleton vs. creating new instances")

        from report.report_synthesis import get_report_synthesizer

        # Set up a test model in the config
        test_model_singleton = "test-model-singleton"
        global_config.config_data['models'][test_model_singleton] = {
            "provider": "openai",
            "model_name": test_model_singleton,
            "temperature": 0.7,
            "max_tokens": 1024
        }

        # Get singleton instance with the test model
        logger.info(f"Getting singleton instance with {test_model_singleton}")
        singleton_synthesizer = get_report_synthesizer(model_name=test_model_singleton)
        singleton_provider = singleton_synthesizer.model_config.get('provider')
        logger.info(f"Singleton provider: {singleton_provider}")

        # Create a new instance with the same model
        logger.info(f"Creating new instance with {test_model_singleton}")
        new_synthesizer = ReportSynthesizer(model_name=test_model_singleton)
        new_provider = new_synthesizer.model_config.get('provider')
        logger.info(f"New instance provider: {new_provider}")

        # Verify both have the same provider
        assert singleton_provider == new_provider, f"Provider mismatch between singleton and new instance: {singleton_provider} vs {new_provider}"
        logger.info("✅ Provider selection consistent between singleton and new instances")

        # Test 5: Edge case with invalid provider
        logger.info("\nTest 5: Edge case with invalid provider")

        # Set up a test model with an invalid provider
        test_model_invalid = "test-model-invalid-provider"
        global_config.config_data['models'][test_model_invalid] = {
            "provider": "invalid_provider",  # This provider doesn't exist
            "model_name": test_model_invalid,
            "temperature": 0.5
        }

        # Create a synthesizer with the invalid provider model
        logger.info(f"Creating synthesizer with invalid provider for {test_model_invalid}")
        invalid_synthesizer = ReportSynthesizer(model_name=test_model_invalid)
        invalid_provider = invalid_synthesizer.model_config.get('provider')

        # The provider should remain as specified in the config, even if invalid
        # This is important for error handling and debugging
        logger.info(f"Provider for invalid model: {invalid_provider}")
        assert invalid_provider == "invalid_provider", f"Invalid provider should be preserved, got {invalid_provider}"
        logger.info("✅ Invalid provider preserved in configuration")

        # Test 6: Provider fallback mechanism
        logger.info("\nTest 6: Provider fallback mechanism")

        # Create a model with no explicit provider
        test_model_no_provider = "test-model-no-provider"
        global_config.config_data['models'][test_model_no_provider] = {
            # No provider specified
            "model_name": test_model_no_provider,
            "temperature": 0.5
        }

        # Create a synthesizer with this model
        logger.info(f"Creating synthesizer with no explicit provider for {test_model_no_provider}")
        no_provider_synthesizer = ReportSynthesizer(model_name=test_model_no_provider)

        # The provider should be inferred based on the model name
        fallback_provider = no_provider_synthesizer.model_config.get('provider')
        logger.info(f"Fallback provider for model with no explicit provider: {fallback_provider}")

        # Since our test model name doesn't match any known pattern, it should default to groq
        assert fallback_provider == "groq", f"Expected fallback to groq, got {fallback_provider}"
        logger.info("✅ Provider fallback mechanism works correctly")

    finally:
        # Restore original config state
        global_config.config_data['models'] = original_models
||||
|
||||
async def test_provider_selection_after_config_reload():
    """
    Test that provider selection remains stable after config reload.

    Simulates a reload by constructing a fresh Config from the same path the
    original was loaded from, swapping it in as the module-level global, and
    verifying that a newly created ReportSynthesizer resolves the same
    provider. The global config and its 'models' dict are always restored in
    the finally block, even on assertion failure.
    """
    logger.info("\n=== Testing provider selection after config reload ===")

    # Import the config module under an unambiguous alias. The original code
    # did `from config.config import config` followed by `import config.config`,
    # which rebound the local name `config` from the instance to the package
    # and made the finally-block restore fragile.
    import config.config as config_module
    from config.config import Config

    global_config = config_module.config

    # Save original config state so it can be restored unconditionally.
    original_models = global_config.config_data.get('models', {}).copy()
    original_config_path = global_config.config_path
    original_config = config_module.config
    config_swapped = False  # Only restore the global if we actually swapped it.

    try:
        # Set up a test model
        test_model = "test-model-config-reload"
        if 'models' not in global_config.config_data:
            global_config.config_data['models'] = {}

        global_config.config_data['models'][test_model] = {
            "provider": "anthropic",
            "model_name": test_model,
            "temperature": 0.5
        }

        # Create a synthesizer with this model before the reload.
        logger.info(f"Creating synthesizer with {test_model} before config reload")
        synthesizer_before = ReportSynthesizer(model_name=test_model)
        provider_before = synthesizer_before.model_config.get('provider')
        logger.info(f"Provider before reload: {provider_before}")

        # Simulate config reload by creating a new Config instance
        logger.info("Simulating config reload...")
        new_config = Config(config_path=original_config_path)

        # Add the same test model (same provider) to the reloaded config.
        if 'models' not in new_config.config_data:
            new_config.config_data['models'] = {}

        new_config.config_data['models'][test_model] = {
            "provider": "anthropic",  # Same provider
            "model_name": test_model,
            "temperature": 0.5
        }

        # Temporarily replace the module-level global config.
        config_module.config = new_config
        config_swapped = True

        # Create a new synthesizer after the reload
        logger.info(f"Creating synthesizer with {test_model} after config reload")
        synthesizer_after = ReportSynthesizer(model_name=test_model)
        provider_after = synthesizer_after.model_config.get('provider')
        logger.info(f"Provider after reload: {provider_after}")

        # Verify the provider remains the same across the reload.
        assert provider_before == provider_after, f"Provider changed after config reload: {provider_before} vs {provider_after}"
        logger.info("✅ Provider selection stable after config reload")

    finally:
        # Restore original config state
        global_config.config_data['models'] = original_models
        # Restore original global config
        if config_swapped:
            config_module.config = original_config
|
||||
async def main():
    """Run the full report synthesis test suite in order."""
    logger.info("Starting report synthesis tests...")
    # Await each test coroutine sequentially, in declaration order.
    for test_coro in (
        test_model_provider_selection,
        test_provider_selection_stability,
        test_provider_selection_after_config_reload,
    ):
        await test_coro()
    logger.info("All tests completed.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
# API settings
|
||||
SECRET_KEY=your-secret-key-here
|
||||
API_V1_STR=/api/v1
|
||||
|
||||
# Database settings
|
||||
DATABASE_URL=sqlite:///./sim-search.db
|
||||
|
||||
# CORS settings
|
||||
CORS_ORIGINS=http://localhost:3000,http://localhost:8000
|
||||
|
||||
# Sim-search settings
|
||||
SIM_SEARCH_PATH=/Volumes/SAM2/CODE/sim-search
|
||||
|
||||
# Default models for different detail levels
|
||||
DEFAULT_MODELS_BRIEF=llama-3.1-8b-instant
|
||||
DEFAULT_MODELS_STANDARD=llama-3.1-8b-instant
|
||||
DEFAULT_MODELS_DETAILED=llama-3.3-70b-versatile
|
||||
DEFAULT_MODELS_COMPREHENSIVE=llama-3.3-70b-versatile
|
|
@ -0,0 +1,165 @@
|
|||
# Sim-Search API
|
||||
|
||||
A FastAPI backend for the Sim-Search intelligent research system.
|
||||
|
||||
## Overview
|
||||
|
||||
This API provides a RESTful interface to the Sim-Search system, allowing for:
|
||||
|
||||
- Query processing and classification
|
||||
- Search execution across multiple engines
|
||||
- Report generation with different detail levels
|
||||
- User authentication and management
|
||||
|
||||
## Architecture
|
||||
|
||||
The API follows a layered architecture:
|
||||
|
||||
1. **API Layer**: FastAPI routes and endpoints
|
||||
2. **Service Layer**: Business logic and integration with Sim-Search
|
||||
3. **Data Layer**: Database models and session management
|
||||
|
||||
## Setup
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Python 3.8+
|
||||
- Sim-Search system installed and configured
|
||||
- API keys for search engines (if using external search engines)
|
||||
|
||||
### Installation
|
||||
|
||||
1. Clone the repository:
|
||||
|
||||
```bash
|
||||
git clone <repository-url>
|
||||
cd sim-search-api
|
||||
```
|
||||
|
||||
2. Create a virtual environment:
|
||||
|
||||
```bash
|
||||
python -m venv venv
|
||||
source venv/bin/activate # On Windows: venv\Scripts\activate
|
||||
```
|
||||
|
||||
3. Install dependencies:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
4. Create a `.env` file based on `.env.example`:
|
||||
|
||||
```bash
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
5. Edit the `.env` file with your configuration settings.
|
||||
|
||||
### Database Setup
|
||||
|
||||
Initialize the database:
|
||||
|
||||
```bash
|
||||
alembic upgrade head
|
||||
```
|
||||
|
||||
## Running the API
|
||||
|
||||
Start the API server:
|
||||
|
||||
```bash
|
||||
python run.py
|
||||
```
|
||||
|
||||
Or with custom settings:
|
||||
|
||||
```bash
|
||||
python run.py --host 0.0.0.0 --port 8000 --reload --debug
|
||||
```
|
||||
|
||||
## API Documentation
|
||||
|
||||
Once the server is running, you can access the API documentation at:
|
||||
|
||||
- Swagger UI: http://localhost:8000/docs
|
||||
- ReDoc: http://localhost:8000/redoc
|
||||
|
||||
## API Endpoints
|
||||
|
||||
### Authentication
|
||||
|
||||
- `POST /api/v1/auth/token`: Get an authentication token
|
||||
- `POST /api/v1/auth/register`: Register a new user
|
||||
|
||||
### Query Processing
|
||||
|
||||
- `POST /api/v1/query/process`: Process and enhance a user query
|
||||
- `POST /api/v1/query/classify`: Classify a query by type and intent
|
||||
|
||||
### Search Execution
|
||||
|
||||
- `POST /api/v1/search/execute`: Execute a search with optional parameters
|
||||
- `GET /api/v1/search/engines`: Get available search engines
|
||||
- `GET /api/v1/search/history`: Get user's search history
|
||||
- `GET /api/v1/search/{search_id}`: Get results for a specific search
|
||||
- `DELETE /api/v1/search/{search_id}`: Delete a search from history
|
||||
|
||||
### Report Generation
|
||||
|
||||
- `POST /api/v1/report/generate`: Generate a report from search results
|
||||
- `GET /api/v1/report/list`: Get a list of user's reports
|
||||
- `GET /api/v1/report/{report_id}`: Get a specific report
|
||||
- `DELETE /api/v1/report/{report_id}`: Delete a report
|
||||
- `GET /api/v1/report/{report_id}/download`: Download a report in specified format
|
||||
|
||||
## Development
|
||||
|
||||
### Project Structure
|
||||
|
||||
```
|
||||
sim-search-api/
|
||||
├── app/
|
||||
│ ├── api/
|
||||
│ │ ├── routes/
|
||||
│ │ │ ├── __init__.py
|
||||
│ │ │ ├── query.py # Query processing endpoints
|
||||
│ │ │ ├── search.py # Search execution endpoints
|
||||
│ │ │ ├── report.py # Report generation endpoints
|
||||
│ │ │ └── auth.py # Authentication endpoints
|
||||
│ │ ├── __init__.py
|
||||
│ │ └── dependencies.py # API dependencies (auth, rate limiting)
|
||||
│ ├── core/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── config.py # API configuration
|
||||
│ │ └── security.py # Security utilities
|
||||
│ ├── db/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── session.py # Database session
|
||||
│ │ └── models.py # Database models for reports, searches
|
||||
│ ├── schemas/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── query.py # Query schemas
|
||||
│ │ ├── search.py # Search result schemas
|
||||
│ │ └── report.py # Report schemas
|
||||
│ ├── services/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── query_service.py # Query processing service
|
||||
│ │ ├── search_service.py # Search execution service
|
||||
│ │ └── report_service.py # Report generation service
|
||||
│ └── main.py # FastAPI application
|
||||
├── alembic/ # Database migrations
|
||||
├── .env.example # Environment variables template
|
||||
└── requirements.txt # Dependencies
|
||||
```
|
||||
|
||||
### Running Tests
|
||||
|
||||
```bash
|
||||
pytest
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
[MIT License](LICENSE)
|
|
@ -0,0 +1,102 @@
|
|||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts
|
||||
script_location = alembic
|
||||
|
||||
# template used to generate migration files
|
||||
# file_template = %%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory.
|
||||
prepend_sys_path = .
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the python-dateutil library that can be
|
||||
# installed by adding `alembic[tz]` to the pip requirements
|
||||
# string value is passed to dateutil.tz.gettz()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the
|
||||
# "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; This defaults
|
||||
# to alembic/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path.
|
||||
# The path separator used here should be the separator specified by "version_path_separator" below.
|
||||
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
|
||||
|
||||
# version path separator; As mentioned above, this is the character used to split
|
||||
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
|
||||
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
|
||||
# Valid values for version_path_separator are:
|
||||
#
|
||||
# version_path_separator = :
|
||||
# version_path_separator = ;
|
||||
# version_path_separator = space
|
||||
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
sqlalchemy.url = sqlite:///./sim-search.db
|
||||
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts. See the documentation for further
|
||||
# detail and examples
|
||||
|
||||
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||
# hooks = black
|
||||
# black.type = console_scripts
|
||||
# black.entrypoint = black
|
||||
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Logging configuration
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
|
@ -0,0 +1,86 @@
|
|||
import os
|
||||
from logging.config import fileConfig
|
||||
|
||||
from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool
|
||||
|
||||
from alembic import context
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
# Override sqlalchemy.url with the value from environment variable
|
||||
sqlalchemy_url = os.getenv("DATABASE_URL", "sqlite:///./sim-search.db")
|
||||
config.set_main_option("sqlalchemy.url", sqlalchemy_url)
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# add your model's MetaData object here
|
||||
# for 'autogenerate' support
|
||||
from app.db.models import Base
|
||||
target_metadata = Base.metadata
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    This configures the context with just a URL
    and not an Engine, though an Engine is acceptable
    here as well. By skipping the Engine creation
    we don't even need a DBAPI to be available.

    Calls to context.execute() here emit the given string to the
    script output.
    """
    # The URL was already overridden from DATABASE_URL at module import time.
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        # Render literal values inline since no DB connection exists to bind them.
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    In this scenario we need to create an Engine
    and associate a connection with the context.
    """
    # Build the engine from the [alembic] ini section; NullPool because the
    # engine is used once for the migration run and then discarded.
    connectable = engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with connectable.connect() as connection:
        context.configure(
            connection=connection, target_metadata=target_metadata
        )

        with context.begin_transaction():
            context.run_migrations()
|
||||
|
||||
|
||||
# Entry point: alembic sets offline mode for `--sql` invocations; otherwise
# run against a live database connection.
if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
|
|
@ -0,0 +1,24 @@
|
|||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = ${repr(up_revision)}
|
||||
down_revision = ${repr(down_revision)}
|
||||
branch_labels = ${repr(branch_labels)}
|
||||
depends_on = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
|
@ -0,0 +1,79 @@
|
|||
"""Initial migration
|
||||
|
||||
Revision ID: 001
|
||||
Revises:
|
||||
Create Date: 2025-03-20
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import sqlite
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '001'
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the initial schema: users, searches, and reports tables.

    Creation order matters: searches references users, and reports
    references both users and searches.
    """
    # Create users table.  IDs are strings (presumably UUIDs assigned by the
    # app — confirm against app.db.models); email is unique and indexed.
    op.create_table(
        'users',
        sa.Column('id', sa.String(), nullable=False),
        sa.Column('email', sa.String(), nullable=False),
        sa.Column('hashed_password', sa.String(), nullable=False),
        sa.Column('full_name', sa.String(), nullable=True),
        sa.Column('is_active', sa.Boolean(), nullable=True, default=True),
        sa.Column('is_superuser', sa.Boolean(), nullable=True, default=False),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('email')
    )
    op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
    op.create_index(op.f('ix_users_id'), 'users', ['id'], unique=True)

    # Create searches table.  Raw engine results are stored as JSON
    # (SQLite dialect type); user_id is nullable, so orphan searches are
    # permitted at the schema level.
    op.create_table(
        'searches',
        sa.Column('id', sa.String(), nullable=False),
        sa.Column('user_id', sa.String(), nullable=True),
        sa.Column('query', sa.String(), nullable=False),
        sa.Column('enhanced_query', sa.String(), nullable=True),
        sa.Column('query_type', sa.String(), nullable=True),
        sa.Column('engines', sa.String(), nullable=True),
        sa.Column('results_count', sa.Integer(), nullable=True, default=0),
        sa.Column('results', sqlite.JSON(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=True, default=sa.func.current_timestamp()),
        sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
        sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_searches_id'), 'searches', ['id'], unique=True)

    # Create reports table.  updated_at is maintained client-side via
    # onupdate (a Python-level default, not a DB trigger).
    op.create_table(
        'reports',
        sa.Column('id', sa.String(), nullable=False),
        sa.Column('user_id', sa.String(), nullable=True),
        sa.Column('search_id', sa.String(), nullable=True),
        sa.Column('title', sa.String(), nullable=False),
        sa.Column('content', sa.Text(), nullable=False),
        sa.Column('detail_level', sa.String(), nullable=False, default='standard'),
        sa.Column('query_type', sa.String(), nullable=True),
        sa.Column('model_used', sa.String(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=True, default=sa.func.current_timestamp()),
        sa.Column('updated_at', sa.DateTime(), nullable=True, default=sa.func.current_timestamp(), onupdate=sa.func.current_timestamp()),
        sa.ForeignKeyConstraint(['search_id'], ['searches.id'], ),
        sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
        sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_reports_id'), 'reports', ['id'], unique=True)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Tear down the initial schema.

    Tables are dropped in reverse creation order (reports, searches, users)
    so foreign-key references are removed before their targets.
    """
    teardown_plan = (
        ('reports', ('ix_reports_id',)),
        ('searches', ('ix_searches_id',)),
        ('users', ('ix_users_id', 'ix_users_email')),
    )
    for table_name, index_names in teardown_plan:
        for index_name in index_names:
            op.drop_index(op.f(index_name), table_name=table_name)
        op.drop_table(table_name)
|
|
@ -0,0 +1,120 @@
|
|||
"""
|
||||
API dependencies for the sim-search API.
|
||||
|
||||
This module provides common dependencies for the API routes.
|
||||
"""
|
||||
|
||||
from typing import Generator, Optional
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import jwt, JWTError
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.security import verify_password
|
||||
from app.db.models import User
|
||||
from app.db.session import get_db
|
||||
from app.schemas.token import TokenPayload
|
||||
|
||||
# OAuth2 scheme for token authentication
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/auth/token")
|
||||
|
||||
|
||||
def get_current_user(
    db: Session = Depends(get_db), token: str = Depends(oauth2_scheme)
) -> User:
    """
    Get the current user from the token.

    Args:
        db: Database session
        token: JWT token (extracted from the Authorization header by oauth2_scheme)

    Returns:
        User object

    Raises:
        HTTPException: 403 if the token is invalid, 404 if the user is not
            found, 400 if the user is inactive
    """
    try:
        # Decode and verify the JWT signature; TokenPayload validation ensures
        # the claims have the expected shape (e.g. a `sub` field).
        payload = jwt.decode(
            token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
        )
        token_data = TokenPayload(**payload)
    except (JWTError, ValidationError):
        # NOTE(review): OAuth2 convention is 401 with a WWW-Authenticate
        # header for bad credentials; this returns 403 — confirm intent
        # before changing, as clients may depend on it.
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Could not validate credentials",
        )

    # `sub` claim carries the user id (set by create_access_token at login).
    user = db.query(User).filter(User.id == token_data.sub).first()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")
    if not user.is_active:
        raise HTTPException(status_code=400, detail="Inactive user")

    return user
|
||||
|
||||
|
||||
def get_current_active_user(
    current_user: User = Depends(get_current_user),
) -> User:
    """
    Dependency that yields the authenticated user, rejecting inactive accounts.

    Args:
        current_user: User resolved by get_current_user

    Returns:
        The same User object when the account is active

    Raises:
        HTTPException: 400 if the user is inactive
    """
    # get_current_user already rejects inactive users, so this guard is a
    # defensive duplicate; kept for parity with the original contract.
    if current_user.is_active:
        return current_user
    raise HTTPException(status_code=400, detail="Inactive user")
|
||||
|
||||
|
||||
def get_current_active_superuser(
    current_user: User = Depends(get_current_user),
) -> User:
    """
    Dependency that yields the authenticated user only if they are a superuser.

    Args:
        current_user: User resolved by get_current_user

    Returns:
        The same User object when the account has superuser rights

    Raises:
        HTTPException: 400 if the user is not a superuser
    """
    if current_user.is_superuser:
        return current_user
    raise HTTPException(
        status_code=400, detail="The user doesn't have enough privileges"
    )
|
||||
|
||||
|
||||
def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
    """
    Look up a user by email and verify their password.

    Args:
        db: Database session
        email: User email
        password: Plain-text password to check against the stored hash

    Returns:
        The matching User when both lookup and password check succeed,
        otherwise None (callers cannot distinguish the two failure modes,
        which avoids leaking which emails exist).
    """
    candidate = db.query(User).filter(User.email == email).first()
    if candidate is None:
        return None
    return candidate if verify_password(password, candidate.hashed_password) else None
|
|
@ -0,0 +1,98 @@
|
|||
"""
|
||||
Authentication routes for the sim-search API.
|
||||
|
||||
This module defines the routes for user authentication and registration.
|
||||
"""
|
||||
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.dependencies import authenticate_user
|
||||
from app.core.config import settings
|
||||
from app.core.security import create_access_token, get_password_hash
|
||||
from app.db.models import User
|
||||
from app.db.session import get_db
|
||||
from app.schemas.token import Token
|
||||
from app.schemas.user import UserCreate, User as UserSchema
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/token", response_model=Token)
|
||||
async def login_for_access_token(
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
db: Session = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
OAuth2 compatible token login, get an access token for future requests.
|
||||
|
||||
Args:
|
||||
form_data: OAuth2 password request form
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Access token
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails
|
||||
"""
|
||||
user = authenticate_user(db, form_data.username, form_data.password)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect email or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token = create_access_token(
|
||||
subject=user.id, expires_delta=access_token_expires
|
||||
)
|
||||
|
||||
return {"access_token": access_token, "token_type": "bearer"}
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserSchema)
|
||||
async def register_user(
|
||||
user_in: UserCreate,
|
||||
db: Session = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Register a new user.
|
||||
|
||||
Args:
|
||||
user_in: User creation data
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Created user
|
||||
|
||||
Raises:
|
||||
HTTPException: If a user with the same email already exists
|
||||
"""
|
||||
# Check if user with this email already exists
|
||||
user = db.query(User).filter(User.email == user_in.email).first()
|
||||
if user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="A user with this email already exists",
|
||||
)
|
||||
|
||||
# Create new user
|
||||
user = User(
|
||||
email=user_in.email,
|
||||
hashed_password=get_password_hash(user_in.password),
|
||||
full_name=user_in.full_name,
|
||||
is_active=user_in.is_active,
|
||||
is_superuser=user_in.is_superuser,
|
||||
)
|
||||
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
|
||||
return user
|
|
@ -0,0 +1,73 @@
|
|||
"""
|
||||
Query routes for the sim-search API.
|
||||
|
||||
This module defines the routes for query processing and classification.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.dependencies import get_current_active_user
|
||||
from app.db.models import User
|
||||
from app.db.session import get_db
|
||||
from app.schemas.query import QueryProcess, QueryClassify, ProcessedQuery
|
||||
from app.services.query_service import QueryService
|
||||
|
||||
router = APIRouter()
|
||||
query_service = QueryService()
|
||||
|
||||
|
||||
@router.post("/process", response_model=ProcessedQuery)
|
||||
async def process_query(
|
||||
query_in: QueryProcess,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Process a query to enhance and structure it.
|
||||
|
||||
Args:
|
||||
query_in: Query to process
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Processed query with structured information
|
||||
"""
|
||||
try:
|
||||
processed_query = await query_service.process_query(query_in.query)
|
||||
return processed_query
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error processing query: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/classify", response_model=ProcessedQuery)
|
||||
async def classify_query(
|
||||
query_in: QueryClassify,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Classify a query by type and intent.
|
||||
|
||||
Args:
|
||||
query_in: Query to classify
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Classified query with type and intent information
|
||||
"""
|
||||
try:
|
||||
classified_query = await query_service.classify_query(query_in.query)
|
||||
return classified_query
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error classifying query: {str(e)}",
|
||||
)
|
|
@ -0,0 +1,294 @@
|
|||
"""
|
||||
Report routes for the sim-search API.
|
||||
|
||||
This module defines the routes for report generation and management.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any, List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.dependencies import get_current_active_user
|
||||
from app.db.models import User, Report, Search
|
||||
from app.db.session import get_db
|
||||
from app.schemas.report import (
|
||||
ReportCreate, ReportUpdate, Report as ReportSchema,
|
||||
ReportList, ReportProgress, ReportDownload
|
||||
)
|
||||
from app.services.report_service import ReportService
|
||||
|
||||
router = APIRouter()
|
||||
report_service = ReportService()
|
||||
|
||||
# Dictionary to store report generation progress
|
||||
report_progress = {}
|
||||
|
||||
|
||||
@router.post("/generate", response_model=ReportSchema)
|
||||
async def generate_report(
|
||||
report_in: ReportCreate,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Generate a report from search results.
|
||||
|
||||
Args:
|
||||
report_in: Report creation parameters
|
||||
background_tasks: FastAPI background tasks
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Generated report
|
||||
"""
|
||||
try:
|
||||
# Check if search_id is provided and exists
|
||||
search = None
|
||||
if report_in.search_id:
|
||||
search = db.query(Search).filter(
|
||||
Search.id == report_in.search_id,
|
||||
Search.user_id == current_user.id
|
||||
).first()
|
||||
|
||||
if not search:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Search not found",
|
||||
)
|
||||
|
||||
# Create report record
|
||||
title = report_in.title or f"Report: {report_in.query}"
|
||||
report = Report(
|
||||
user_id=current_user.id,
|
||||
search_id=report_in.search_id,
|
||||
title=title,
|
||||
content="Report generation in progress...",
|
||||
detail_level=report_in.detail_level or "standard",
|
||||
query_type=report_in.query_type,
|
||||
model_used=report_in.model,
|
||||
)
|
||||
|
||||
db.add(report)
|
||||
db.commit()
|
||||
db.refresh(report)
|
||||
|
||||
# Initialize progress tracking
|
||||
report_progress[report.id] = {
|
||||
"progress": 0.0,
|
||||
"status": "Initializing report generation...",
|
||||
"current_chunk": 0,
|
||||
"total_chunks": 0,
|
||||
"current_report": "Report generation in progress...",
|
||||
}
|
||||
|
||||
# Generate report in background
|
||||
background_tasks.add_task(
|
||||
report_service.generate_report_background,
|
||||
report_id=report.id,
|
||||
report_in=report_in,
|
||||
search=search,
|
||||
db=db,
|
||||
progress_dict=report_progress,
|
||||
)
|
||||
|
||||
return report
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error generating report: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/list", response_model=ReportList)
|
||||
async def list_reports(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get a list of user's reports.
|
||||
|
||||
Args:
|
||||
skip: Number of records to skip
|
||||
limit: Maximum number of records to return
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
List of reports
|
||||
"""
|
||||
reports = db.query(Report).filter(Report.user_id == current_user.id).order_by(
|
||||
Report.created_at.desc()
|
||||
).offset(skip).limit(limit).all()
|
||||
|
||||
total = db.query(Report).filter(Report.user_id == current_user.id).count()
|
||||
|
||||
return {"reports": reports, "total": total}
|
||||
|
||||
|
||||
@router.get("/{report_id}", response_model=ReportSchema)
|
||||
async def get_report(
|
||||
report_id: str,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get a specific report.
|
||||
|
||||
Args:
|
||||
report_id: ID of the report
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Report
|
||||
|
||||
Raises:
|
||||
HTTPException: If the report is not found or doesn't belong to the user
|
||||
"""
|
||||
report = db.query(Report).filter(
|
||||
Report.id == report_id, Report.user_id == current_user.id
|
||||
).first()
|
||||
|
||||
if not report:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Report not found",
|
||||
)
|
||||
|
||||
return report
|
||||
|
||||
|
||||
@router.delete("/{report_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_report(
|
||||
report_id: str,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> None:
|
||||
"""
|
||||
Delete a report.
|
||||
|
||||
Args:
|
||||
report_id: ID of the report to delete
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Raises:
|
||||
HTTPException: If the report is not found or doesn't belong to the user
|
||||
"""
|
||||
report = db.query(Report).filter(
|
||||
Report.id == report_id, Report.user_id == current_user.id
|
||||
).first()
|
||||
|
||||
if not report:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Report not found",
|
||||
)
|
||||
|
||||
db.delete(report)
|
||||
db.commit()
|
||||
|
||||
|
||||
@router.get("/{report_id}/progress", response_model=ReportProgress)
|
||||
async def get_report_progress(
|
||||
report_id: str,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get the progress of a report generation.
|
||||
|
||||
Args:
|
||||
report_id: ID of the report
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Report generation progress
|
||||
|
||||
Raises:
|
||||
HTTPException: If the report is not found or doesn't belong to the user
|
||||
"""
|
||||
# Check if report exists and belongs to user
|
||||
report = db.query(Report).filter(
|
||||
Report.id == report_id, Report.user_id == current_user.id
|
||||
).first()
|
||||
|
||||
if not report:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Report not found",
|
||||
)
|
||||
|
||||
# Get progress from progress dictionary
|
||||
progress_data = report_progress.get(report_id, {
|
||||
"progress": 1.0,
|
||||
"status": "Report generation complete",
|
||||
"current_chunk": 0,
|
||||
"total_chunks": 0,
|
||||
"current_report": None,
|
||||
})
|
||||
|
||||
return {
|
||||
"report_id": report_id,
|
||||
"progress": progress_data.get("progress", 1.0),
|
||||
"status": progress_data.get("status", "Report generation complete"),
|
||||
"current_chunk": progress_data.get("current_chunk", 0),
|
||||
"total_chunks": progress_data.get("total_chunks", 0),
|
||||
"current_report": progress_data.get("current_report", None),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{report_id}/download")
|
||||
async def download_report(
|
||||
report_id: str,
|
||||
format: str = "markdown",
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Download a report in the specified format.
|
||||
|
||||
Args:
|
||||
report_id: ID of the report
|
||||
format: Format of the report (markdown, html, pdf)
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Report file
|
||||
|
||||
Raises:
|
||||
HTTPException: If the report is not found or doesn't belong to the user
|
||||
"""
|
||||
report = db.query(Report).filter(
|
||||
Report.id == report_id, Report.user_id == current_user.id
|
||||
).first()
|
||||
|
||||
if not report:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Report not found",
|
||||
)
|
||||
|
||||
# Generate file in the requested format
|
||||
try:
|
||||
file_path = await report_service.generate_report_file(report, format)
|
||||
|
||||
# Return file
|
||||
return FileResponse(
|
||||
path=file_path,
|
||||
filename=f"report_{report_id}.{format}",
|
||||
media_type="application/octet-stream",
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error generating report file: {str(e)}",
|
||||
)
|
|
@ -0,0 +1,177 @@
|
|||
"""
|
||||
Search routes for the sim-search API.
|
||||
|
||||
This module defines the routes for search execution and history.
|
||||
"""
|
||||
|
||||
from typing import Any, List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.dependencies import get_current_active_user
|
||||
from app.db.models import User, Search
|
||||
from app.db.session import get_db
|
||||
from app.schemas.search import SearchExecute, SearchResults, SearchHistory, SearchHistoryList
|
||||
from app.services.search_service import SearchService
|
||||
|
||||
router = APIRouter()
|
||||
search_service = SearchService()
|
||||
|
||||
|
||||
@router.post("/execute", response_model=SearchResults)
|
||||
async def execute_search(
|
||||
search_in: SearchExecute,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Execute a search with the given parameters.
|
||||
|
||||
Args:
|
||||
search_in: Search parameters
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Search results
|
||||
"""
|
||||
try:
|
||||
# Get the structured query from the input
|
||||
structured_query = search_in.structured_query.model_dump() if search_in.structured_query else {}
|
||||
|
||||
# Print for debugging
|
||||
print(f"Executing search with structured_query: {structured_query}")
|
||||
|
||||
# Call the search service
|
||||
search_results = await search_service.execute_search(
|
||||
structured_query=structured_query, # Explicitly use keyword argument
|
||||
search_engines=search_in.search_engines,
|
||||
num_results=search_in.num_results,
|
||||
timeout=search_in.timeout,
|
||||
user_id=current_user.id,
|
||||
db=db,
|
||||
)
|
||||
return search_results
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error executing search: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/engines", response_model=List[str])
|
||||
async def get_available_search_engines(
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
) -> Any:
|
||||
"""
|
||||
Get a list of available search engines.
|
||||
|
||||
Args:
|
||||
current_user: Current authenticated user
|
||||
|
||||
Returns:
|
||||
List of available search engine names
|
||||
"""
|
||||
try:
|
||||
engines = await search_service.get_available_search_engines()
|
||||
return engines
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error getting available search engines: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/history", response_model=SearchHistoryList)
|
||||
async def get_search_history(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get the user's search history.
|
||||
|
||||
Args:
|
||||
skip: Number of records to skip
|
||||
limit: Maximum number of records to return
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
List of search history records
|
||||
"""
|
||||
searches = db.query(Search).filter(Search.user_id == current_user.id).order_by(
|
||||
Search.created_at.desc()
|
||||
).offset(skip).limit(limit).all()
|
||||
|
||||
total = db.query(Search).filter(Search.user_id == current_user.id).count()
|
||||
|
||||
return {"searches": searches, "total": total}
|
||||
|
||||
|
||||
@router.get("/{search_id}", response_model=SearchResults)
|
||||
async def get_search_results(
|
||||
search_id: str,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get results for a specific search.
|
||||
|
||||
Args:
|
||||
search_id: ID of the search
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Search results
|
||||
|
||||
Raises:
|
||||
HTTPException: If the search is not found or doesn't belong to the user
|
||||
"""
|
||||
search = db.query(Search).filter(
|
||||
Search.id == search_id, Search.user_id == current_user.id
|
||||
).first()
|
||||
|
||||
if not search:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Search not found",
|
||||
)
|
||||
|
||||
return await search_service.get_search_results(search)
|
||||
|
||||
|
||||
@router.delete("/{search_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_search(
|
||||
search_id: str,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> None:
|
||||
"""
|
||||
Delete a search from history.
|
||||
|
||||
Args:
|
||||
search_id: ID of the search to delete
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Raises:
|
||||
HTTPException: If the search is not found or doesn't belong to the user
|
||||
"""
|
||||
search = db.query(Search).filter(
|
||||
Search.id == search_id, Search.user_id == current_user.id
|
||||
).first()
|
||||
|
||||
if not search:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Search not found",
|
||||
)
|
||||
|
||||
db.delete(search)
|
||||
db.commit()
|
|
@ -0,0 +1,65 @@
|
|||
"""
|
||||
Configuration settings for the sim-search API.
|
||||
|
||||
This module defines the settings for the API, loaded from environment variables.
|
||||
"""
|
||||
|
||||
import os
|
||||
import secrets
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
|
||||
from pydantic import AnyHttpUrl, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Settings for the sim-search API, loaded from environment / .env file."""

    # API settings
    API_V1_STR: str = "/api/v1"
    PROJECT_NAME: str = "Sim-Search API"
    PROJECT_DESCRIPTION: str = "API for the Sim-Search intelligent research system"
    VERSION: str = "0.1.0"

    # Security settings.
    # NOTE: the fallback SECRET_KEY is regenerated on every process start,
    # which invalidates all outstanding JWTs on restart — set SECRET_KEY in
    # the environment for any real deployment.
    SECRET_KEY: str = os.getenv("SECRET_KEY", secrets.token_urlsafe(32))
    ALGORITHM: str = "HS256"
    ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 7  # 7 days

    # CORS settings (default allows every origin; see validator below)
    CORS_ORIGINS: List[str] = ["*"]

    @field_validator("CORS_ORIGINS", mode="before")
    @classmethod
    def assemble_cors_origins(cls, v: Union[str, List[str]]) -> Union[List[str], str]:
        """Parse CORS origins from a comma-separated string or a list."""
        if isinstance(v, str) and not v.startswith("["):
            return [i.strip() for i in v.split(",")]
        elif isinstance(v, (list, str)):
            return v
        raise ValueError(v)

    # Database settings (plain literal — the previous f-string had no placeholders)
    SQLALCHEMY_DATABASE_URI: str = os.getenv(
        "DATABASE_URL", "sqlite:///./sim-search.db"
    )

    # Sim-search settings: filesystem path of the sim-search checkout to import from
    SIM_SEARCH_PATH: str = os.getenv("SIM_SEARCH_PATH", "/Volumes/SAM2/CODE/sim-search")

    # Default models for different report detail levels
    DEFAULT_MODELS: Dict[str, str] = {
        "brief": "llama-3.1-8b-instant",
        "standard": "llama-3.1-8b-instant",
        "detailed": "llama-3.3-70b-versatile",
        "comprehensive": "llama-3.3-70b-versatile"
    }

    model_config = {
        "case_sensitive": True,
        "env_file": ".env",
    }


# Create the shared settings instance used throughout the application
settings = Settings()
|
|
@ -0,0 +1,72 @@
|
|||
"""
|
||||
Security utilities for the sim-search API.
|
||||
|
||||
This module provides utilities for password hashing, JWT token generation,
|
||||
and token validation.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from jose import jwt
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# Password hashing context
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """
    Check a plaintext password against its stored hash.

    Args:
        plain_password: Candidate password supplied by the caller
        hashed_password: Hash previously produced by get_password_hash

    Returns:
        True when the password matches the hash, False otherwise
    """
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
    """
    Hash a plaintext password for storage.

    Args:
        password: Plain text password

    Returns:
        Salted hash suitable for persisting and later verification
    """
    hashed = pwd_context.hash(password)
    return hashed
|
||||
|
||||
|
||||
def create_access_token(
    subject: Union[str, Any], expires_delta: Optional[timedelta] = None
) -> str:
    """
    Create a signed JWT access token.

    Args:
        subject: Subject of the token (usually user ID)
        expires_delta: Optional expiration time delta; defaults to
            settings.ACCESS_TOKEN_EXPIRE_MINUTES

    Returns:
        JWT token as a string
    """
    # Local import keeps the module-level import list untouched.
    from datetime import timezone

    # datetime.utcnow() is deprecated (Python 3.12+) and returns a naive
    # datetime; use an explicit timezone-aware UTC timestamp instead.
    now = datetime.now(timezone.utc)
    if expires_delta:
        expire = now + expires_delta
    else:
        expire = now + timedelta(
            minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
        )

    to_encode = {"exp": expire, "sub": str(subject)}
    encoded_jwt = jwt.encode(
        to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
    )

    return encoded_jwt
|
|
@ -0,0 +1,74 @@
|
|||
"""
|
||||
Database models for the sim-search API.
|
||||
|
||||
This module defines the SQLAlchemy ORM models for the database.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import uuid
|
||||
from typing import Dict, Any, List
|
||||
|
||||
from sqlalchemy import Column, String, Boolean, ForeignKey, DateTime, Integer, JSON, Text
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.db.session import Base
|
||||
|
||||
|
||||
def generate_uuid() -> str:
    """Return a fresh random UUID (version 4) in canonical string form."""
    new_id = uuid.uuid4()
    return str(new_id)
|
||||
|
||||
|
||||
class User(Base):
    """User account persisted in the ``users`` table."""

    __tablename__ = "users"

    # Primary key is a client-generated UUID string (see generate_uuid).
    id = Column(String, primary_key=True, index=True, default=generate_uuid)
    email = Column(String, unique=True, index=True, nullable=False)  # login identifier
    hashed_password = Column(String, nullable=False)  # hash only, never plaintext
    full_name = Column(String, nullable=True)
    is_active = Column(Boolean, default=True)
    is_superuser = Column(Boolean, default=False)

    # One-to-many: a user owns their searches and reports.
    searches = relationship("Search", back_populates="user")
    reports = relationship("Report", back_populates="user")


class Search(Base):
    """A stored search execution and its results (``searches`` table)."""

    __tablename__ = "searches"

    id = Column(String, primary_key=True, index=True, default=generate_uuid)
    user_id = Column(String, ForeignKey("users.id"))
    query = Column(String, nullable=False)  # original user query
    enhanced_query = Column(String, nullable=True)  # LLM-rewritten query, if any
    query_type = Column(String, nullable=True)
    engines = Column(String, nullable=True)  # Comma-separated list
    results_count = Column(Integer, default=0)
    results = Column(JSON, nullable=True)  # raw results payload stored as JSON
    # NOTE(review): datetime.utcnow is deprecated (naive timestamps); changing
    # the default would alter stored values — confirm before migrating.
    created_at = Column(DateTime, default=datetime.datetime.utcnow)

    user = relationship("User", back_populates="searches")
    reports = relationship("Report", back_populates="search")


class Report(Base):
    """A generated research report (``reports`` table)."""

    __tablename__ = "reports"

    id = Column(String, primary_key=True, index=True, default=generate_uuid)
    user_id = Column(String, ForeignKey("users.id"))
    # Optional link to the search the report was generated from.
    search_id = Column(String, ForeignKey("searches.id"), nullable=True)
    title = Column(String, nullable=False)
    content = Column(Text, nullable=False)  # report body (markdown)
    detail_level = Column(String, nullable=False, default="standard")
    query_type = Column(String, nullable=True)
    model_used = Column(String, nullable=True)  # LLM that produced the content
    created_at = Column(DateTime, default=datetime.datetime.utcnow)
    # Automatically refreshed on every UPDATE via onupdate.
    updated_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow)

    user = relationship("User", back_populates="reports")
    search = relationship("Search", back_populates="reports")
|
|
@ -0,0 +1,38 @@
|
|||
"""
|
||||
Database session management for the sim-search API.
|
||||
|
||||
This module provides utilities for creating and managing database sessions.
|
||||
"""
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# Create the SQLAlchemy engine.
# pool_pre_ping issues a lightweight liveness check before each connection
# checkout so stale connections are transparently replaced.
# check_same_thread=False is needed for SQLite because the web framework may
# use a connection from a different thread than the one that created it.
engine = create_engine(
    settings.SQLALCHEMY_DATABASE_URI,
    pool_pre_ping=True,
    connect_args={"check_same_thread": False} if settings.SQLALCHEMY_DATABASE_URI.startswith("sqlite") else {},
)

# Session factory: sessions require explicit commit (autocommit/autoflush off).
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Declarative base class that all ORM models inherit from.
Base = declarative_base()
|
||||
|
||||
|
||||
def get_db():
    """
    Dependency generator that provides a database session.

    The session is always closed when the request finishes, even if the
    handler raised.

    Yields:
        SQLAlchemy session
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
|
|
@ -0,0 +1,80 @@
|
|||
"""
|
||||
Main FastAPI application for the sim-search API.
|
||||
|
||||
This module defines the FastAPI application and includes all routes.
|
||||
"""
|
||||
|
||||
import os
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.openapi.docs import get_swagger_ui_html
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
from app.api.routes import query, search, report, auth
|
||||
from app.core.config import settings
|
||||
|
||||
# Create FastAPI app. Built-in /docs and /redoc are disabled; a custom /docs
# route below serves Swagger UI against our versioned OpenAPI endpoint.
app = FastAPI(
    title=settings.PROJECT_NAME,
    description=settings.PROJECT_DESCRIPTION,
    version=settings.VERSION,
    docs_url=None,  # Disable default docs
    redoc_url=None,  # Disable default redoc
)

# Set up CORS middleware.
# NOTE(review): CORS_ORIGINS defaults to ["*"]; wildcard origins combined with
# allow_credentials=True are rejected by browsers and unsafe for production —
# confirm the deployment pins concrete origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Include routers — each feature area is mounted under the versioned prefix.
app.include_router(auth.router, prefix=f"{settings.API_V1_STR}/auth", tags=["Authentication"])
app.include_router(query.router, prefix=f"{settings.API_V1_STR}/query", tags=["Query Processing"])
app.include_router(search.router, prefix=f"{settings.API_V1_STR}/search", tags=["Search Execution"])
app.include_router(report.router, prefix=f"{settings.API_V1_STR}/report", tags=["Report Generation"])
|
||||
|
||||
# Custom OpenAPI and documentation endpoints
@app.get("/docs", include_in_schema=False)
async def custom_swagger_ui_html():
    """Serve Swagger UI (CDN assets) against the custom OpenAPI endpoint."""
    docs_title = f"{settings.PROJECT_NAME} - Swagger UI"
    return get_swagger_ui_html(
        openapi_url=f"{settings.API_V1_STR}/openapi.json",
        title=docs_title,
        oauth2_redirect_url=f"{settings.API_V1_STR}/docs/oauth2-redirect",
        swagger_js_url="https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui-bundle.js",
        swagger_css_url="https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui.css",
    )
|
||||
|
||||
@app.get(f"{settings.API_V1_STR}/openapi.json", include_in_schema=False)
|
||||
async def get_open_api_endpoint():
|
||||
"""Return OpenAPI schema."""
|
||||
return get_openapi(
|
||||
title=settings.PROJECT_NAME,
|
||||
version=settings.VERSION,
|
||||
description=settings.PROJECT_DESCRIPTION,
|
||||
routes=app.routes,
|
||||
)
|
||||
|
||||
@app.get("/", tags=["Status"])
|
||||
async def root():
|
||||
"""Root endpoint to check API status."""
|
||||
return {
|
||||
"status": "online",
|
||||
"version": settings.VERSION,
|
||||
"project": settings.PROJECT_NAME,
|
||||
"docs": "/docs"
|
||||
}
|
||||
|
||||
# Initialize components on startup.
# NOTE(review): @app.on_event is deprecated in recent FastAPI in favor of
# lifespan handlers; consider migrating when upgrading FastAPI.
@app.on_event("startup")
async def startup_event():
    """Initialize heavyweight components once when the application starts."""
    # Import here to avoid circular imports
    from app.services.report_service import initialize_report_generator

    # Initialize report generator
    await initialize_report_generator()
|
|
@ -0,0 +1,99 @@
|
|||
"""
|
||||
Query schemas for the sim-search API.
|
||||
|
||||
This module defines the Pydantic models for query-related operations.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class QueryBase(BaseModel):
    """Base query schema shared by all query request bodies."""

    query: str  # raw user-supplied query text


class QueryProcess(QueryBase):
    """Request body for query processing (no extra fields beyond the query)."""

    pass


class QueryClassify(QueryBase):
    """Request body for query classification (no extra fields beyond the query)."""

    pass


class SubQuestion(BaseModel):
    """One decomposed sub-question of the original query."""

    sub_question: str
    aspect: str  # facet of the original query this sub-question covers
    priority: float  # relative importance — presumably 0..1, confirm upstream


class StructuredQuery(BaseModel):
    """Structured representation of a processed/classified query.

    All fields except original_query are optional because different pipeline
    stages (processing vs. classification) populate different subsets.
    """

    original_query: str
    enhanced_query: Optional[str] = None  # LLM-rewritten query, if produced
    type: Optional[str] = None
    intent: Optional[str] = None
    domain: Optional[str] = None
    confidence: Optional[float] = None  # classifier confidence — presumably 0..1
    reasoning: Optional[str] = None  # classifier's free-text rationale
    entities: Optional[List[str]] = None
    sub_questions: Optional[List[SubQuestion]] = None
    search_queries: Optional[Dict[str, str]] = None  # per-engine query strings
    is_academic: Optional[bool] = None
    is_code: Optional[bool] = None
    is_current_events: Optional[bool] = None


class ProcessedQuery(BaseModel):
    """Response schema pairing the original query with its structured form."""

    original_query: str
    structured_query: StructuredQuery

    class Config:
        """Pydantic config (OpenAPI example payload)."""

        json_schema_extra = {
            "example": {
                "original_query": "What are the latest advancements in quantum computing?",
                "structured_query": {
                    "original_query": "What are the latest advancements in quantum computing?",
                    "enhanced_query": "What are the recent breakthroughs and developments in quantum computing technology, algorithms, and applications in the past 2 years?",
                    "type": "exploratory",
                    "intent": "research",
                    "domain": "academic",
                    "confidence": 0.95,
                    "reasoning": "This query is asking about recent developments in a scientific field, which is typical of academic research.",
                    "entities": ["quantum computing", "advancements"],
                    "sub_questions": [
                        {
                            "sub_question": "What are the latest hardware advancements in quantum computing?",
                            "aspect": "hardware",
                            "priority": 0.9
                        },
                        {
                            "sub_question": "What are the recent algorithmic breakthroughs in quantum computing?",
                            "aspect": "algorithms",
                            "priority": 0.8
                        }
                    ],
                    "search_queries": {
                        "google": "latest advancements in quantum computing 2024",
                        "scholar": "recent quantum computing breakthroughs",
                        "arxiv": "quantum computing hardware algorithms"
                    },
                    "is_academic": True,
                    "is_code": False,
                    "is_current_events": False
                }
            }
        }
|
|
@ -0,0 +1,84 @@
|
|||
"""
|
||||
Report schemas for the sim-search API.
|
||||
|
||||
This module defines the Pydantic models for report-related operations.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ReportBase(BaseModel):
    """Fields common to report create/update/read schemas."""

    title: Optional[str] = None
    detail_level: Optional[str] = "standard"  # brief/standard/detailed/comprehensive — confirm against detail-level manager
    query_type: Optional[str] = None
    model: Optional[str] = None  # LLM to use; server picks a default when None


class ReportCreate(ReportBase):
    """Request body for generating a new report."""

    # Exactly one source of results is used: an existing search, inline
    # results, or (when both are absent) a fresh search on `query`.
    search_id: Optional[str] = None
    search_results: Optional[List[Dict[str, Any]]] = None
    query: str
    token_budget: Optional[int] = None
    chunk_size: Optional[int] = None
    overlap_size: Optional[int] = None


class ReportUpdate(ReportBase):
    """Request body for updating an existing report."""

    content: Optional[str] = None


class ReportInDBBase(ReportBase):
    """Fields present on a persisted report (mirrors the ORM model)."""

    id: str
    user_id: str
    search_id: Optional[str] = None
    content: str
    model_used: Optional[str] = None
    created_at: datetime
    updated_at: datetime

    class Config:
        """Pydantic config."""

        # Allow construction directly from ORM objects.
        from_attributes = True


class Report(ReportInDBBase):
    """Report schema returned to API clients."""

    pass


class ReportList(BaseModel):
    """Paginated list of reports."""

    reports: List[Report]
    total: int  # total matching records, not just this page


class ReportProgress(BaseModel):
    """Progress snapshot for an in-flight report generation."""

    report_id: str
    progress: float  # fraction complete — presumably 0..1, confirm producer
    status: str  # human-readable status message
    current_chunk: Optional[int] = None
    total_chunks: Optional[int] = None
    current_report: Optional[str] = None  # partial report text, if available


class ReportDownload(BaseModel):
    """Request body for downloading a report in a given format."""

    report_id: str
    format: str = "markdown"  # markdown, html, pdf
|
|
@ -0,0 +1,75 @@
|
|||
"""
|
||||
Search schemas for the sim-search API.
|
||||
|
||||
This module defines the Pydantic models for search-related operations.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.schemas.query import StructuredQuery
|
||||
|
||||
|
||||
class SearchResult(BaseModel):
    """A single result returned by one search engine."""

    title: str
    url: str
    snippet: str
    source: str  # engine that produced this result
    score: Optional[float] = None  # engine-specific relevance score
    authors: Optional[Any] = None  # Can be string or list of strings
    year: Optional[str] = None
    # Fields below are populated mainly by academic sources (e.g. arXiv).
    pdf_url: Optional[str] = None
    arxiv_id: Optional[str] = None
    categories: Optional[List[str]] = None
    published_date: Optional[str] = None
    updated_date: Optional[str] = None
    full_text: Optional[str] = None


class SearchExecute(BaseModel):
    """Request body for executing a search."""

    structured_query: StructuredQuery
    search_engines: Optional[List[str]] = None  # None = server default engines
    num_results: Optional[int] = 10
    timeout: Optional[int] = 30  # presumably seconds — confirm in search service


class SearchResults(BaseModel):
    """Response schema for an executed search, grouped by engine."""

    search_id: str
    query: str
    enhanced_query: Optional[str] = None
    results: Dict[str, List[SearchResult]]  # engine name -> its results
    total_results: int
    execution_time: float  # presumably seconds — confirm producer
    # NOTE(review): datetime.utcnow is deprecated and naive; switching to an
    # aware default would change serialized timestamps — confirm before changing.
    timestamp: datetime = Field(default_factory=datetime.utcnow)


class SearchHistory(BaseModel):
    """One entry in a user's search history (mirrors the Search ORM model)."""

    id: str
    query: str
    enhanced_query: Optional[str] = None
    query_type: Optional[str] = None
    engines: Optional[str] = None  # comma-separated engine names
    results_count: int
    created_at: datetime

    class Config:
        """Pydantic config."""

        # Allow construction directly from ORM objects.
        from_attributes = True


class SearchHistoryList(BaseModel):
    """Paginated search history."""

    searches: List[SearchHistory]
    total: int  # total matching records, not just this page
|
|
@ -0,0 +1,28 @@
|
|||
"""
|
||||
Token schemas for the sim-search API.
|
||||
|
||||
This module defines the Pydantic models for token-related operations.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Token(BaseModel):
    """OAuth2 access-token response."""

    access_token: str  # encoded JWT
    token_type: str  # "bearer" per the auth endpoint contract


class TokenPayload(BaseModel):
    """Decoded JWT payload."""

    sub: Optional[str] = None  # subject claim — the user ID


class TokenData(BaseModel):
    """Token data schema."""

    username: Optional[str] = None
|
|
@ -0,0 +1,52 @@
|
|||
"""
|
||||
User schemas for the sim-search API.
|
||||
|
||||
This module defines the Pydantic models for user-related operations.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, EmailStr
|
||||
|
||||
|
||||
class UserBase(BaseModel):
    """Fields common to user create/update/read schemas."""

    email: Optional[EmailStr] = None
    is_active: Optional[bool] = True
    is_superuser: bool = False
    full_name: Optional[str] = None


class UserCreate(UserBase):
    """Registration payload — email and plaintext password are required."""

    email: EmailStr
    password: str  # plaintext here; hashed before storage


class UserUpdate(UserBase):
    """Update payload — any field may be changed, including the password."""

    password: Optional[str] = None


class UserInDBBase(UserBase):
    """Fields present on a persisted user."""

    id: Optional[str] = None

    class Config:
        # Allow construction directly from ORM objects.
        from_attributes = True


class User(UserInDBBase):
    """User schema returned to API clients (never exposes the hash)."""

    pass


class UserInDB(UserInDBBase):
    """Internal schema that includes the stored password hash."""

    hashed_password: str
|
|
@ -0,0 +1,85 @@
|
|||
"""
|
||||
Query service for the sim-search API.
|
||||
|
||||
This module provides services for query processing and classification.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# Add sim-search to the python path
|
||||
sim_search_path = Path(settings.SIM_SEARCH_PATH)
|
||||
sys.path.append(str(sim_search_path))
|
||||
|
||||
# Import sim-search components
|
||||
from query.query_processor import QueryProcessor
|
||||
from query.llm_interface import LLMInterface
|
||||
|
||||
|
||||
class QueryService:
    """
    Service wrapper around sim-search's query processing components.

    Exposes async helpers that adapt the sim-search QueryProcessor and
    LLMInterface to the API's response shape.
    """

    def __init__(self):
        """Initialize the query service."""
        self.query_processor = QueryProcessor()
        self.llm_interface = LLMInterface()

    async def process_query(self, query: str) -> Dict[str, Any]:
        """
        Process a query to enhance and structure it.

        Args:
            query: Query to process

        Returns:
            Processed query with structured information
        """
        # Delegate to the sim-search query processor, then wrap in the
        # API response shape.
        structured = await self.query_processor.process_query(query)
        return {
            "original_query": query,
            "structured_query": structured,
        }

    async def classify_query(self, query: str) -> Dict[str, Any]:
        """
        Classify a query by type and intent.

        Args:
            query: Query to classify

        Returns:
            Classified query with type and intent information
        """
        classification = await self.llm_interface.classify_query_domain(query)

        # Copy classification fields into the structured-query dict; the
        # boolean flags default to False when absent.
        structured: Dict[str, Any] = {"original_query": query}
        for field in ("type", "intent", "domain", "confidence", "reasoning"):
            structured[field] = classification.get(field)
        for flag in ("is_academic", "is_code", "is_current_events"):
            structured[flag] = classification.get(flag, False)

        return {
            "original_query": query,
            "structured_query": structured,
        }
|
|
@ -0,0 +1,355 @@
|
|||
"""
|
||||
Report service for the sim-search API.
|
||||
|
||||
This module provides services for report generation and management.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
import asyncio
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.db.models import Search, Report
|
||||
from app.schemas.report import ReportCreate
|
||||
|
||||
# Add sim-search to the python path
|
||||
sim_search_path = Path(settings.SIM_SEARCH_PATH)
|
||||
sys.path.append(str(sim_search_path))
|
||||
|
||||
# Import sim-search components
|
||||
from report.report_generator import get_report_generator, initialize_report_generator
|
||||
from report.report_detail_levels import get_report_detail_level_manager
|
||||
from app.services.search_service import SearchService
|
||||
|
||||
|
||||
class ReportService:
|
||||
"""
|
||||
Service for report generation and management.
|
||||
|
||||
This class provides methods to generate and manage reports using
|
||||
the sim-search report generation functionality.
|
||||
"""
|
||||
|
||||
    def __init__(self):
        """Initialize the report service.

        The report generator itself is created lazily in initialize(),
        since it requires async setup.
        """
        # Set by initialize(); stays None until then.
        self.report_generator = None
        self.detail_level_manager = get_report_detail_level_manager()
        self.search_service = SearchService()
        # Scratch directory for generated report files.
        self.temp_dir = Path(tempfile.gettempdir()) / "sim-search-api"
        self.temp_dir.mkdir(exist_ok=True)
|
||||
|
||||
    async def initialize(self):
        """Asynchronously set up the shared report generator.

        NOTE(review): calling this more than once re-awaits the global
        initializer — presumably idempotent; confirm in report.report_generator.
        """
        await initialize_report_generator()
        self.report_generator = get_report_generator()
|
||||
|
||||
    async def generate_report_background(
        self,
        report_id: str,
        report_in: ReportCreate,
        search: Optional[Search] = None,
        db: Optional[Session] = None,
        progress_dict: Optional[Dict[str, Dict[str, Any]]] = None,
    ) -> None:
        """
        Generate a report in the background.

        Resolves search results (from the given Search row, from the request
        body, or by running a fresh search), streams progress into
        progress_dict, and persists the finished content on the Report row.
        On failure the error text is written to both the progress entry and
        the report content, then the exception is re-raised.

        Args:
            report_id: ID of the report row to fill in
            report_in: Report creation parameters
            search: Search record whose stored results should be used
            db: Database session (no DB writes happen when None)
            progress_dict: Shared dict keyed by report_id for progress updates
        """
        try:
            # Initialize report generator if not already initialized
            if self.report_generator is None:
                await self.initialize()

            # Resolve search results from one of three sources, in priority
            # order: stored search row, inline request results, fresh search.
            search_results = []
            if search:
                # Use search results from the database.
                # NOTE(review): search.results is a JSON column — assumed to be
                # a flat list of result dicts; confirm against the writer.
                search_results = search.results
            elif report_in.search_results:
                # Use search results provided in the request
                search_results = report_in.search_results
            else:
                # Execute a new search
                structured_query = {
                    "original_query": report_in.query,
                    "enhanced_query": report_in.query,
                }

                search_results_dict = await self.search_service.execute_search(
                    structured_query=structured_query,
                    num_results=10,
                )

                # Flatten the per-engine result lists into one list.
                for engine_results in search_results_dict["results"].values():
                    search_results.extend(engine_results)

            # Set up progress tracking: the callback mutates the shared dict
            # that the progress endpoint reads from.
            if progress_dict is not None:
                def progress_callback(current_progress, total_chunks, current_report):
                    if report_id in progress_dict:
                        progress_dict[report_id] = {
                            "progress": current_progress,
                            "status": f"Processing chunk {int(current_progress * total_chunks)}/{total_chunks}...",
                            "current_chunk": int(current_progress * total_chunks),
                            "total_chunks": total_chunks,
                            "current_report": current_report,
                        }

                self.report_generator.set_progress_callback(progress_callback)

            # Set detail level
            if report_in.detail_level:
                self.report_generator.set_detail_level(report_in.detail_level)

            # Set model if provided
            if report_in.model:
                self.report_generator.set_model(report_in.model)

            # Generate report
            report_content = await self.report_generator.generate_report(
                search_results=search_results,
                query=report_in.query,
                token_budget=report_in.token_budget,
                chunk_size=report_in.chunk_size,
                overlap_size=report_in.overlap_size,
                detail_level=report_in.detail_level,
                query_type=report_in.query_type,
            )

            # Persist the finished content and the model actually used.
            if db:
                report = db.query(Report).filter(Report.id == report_id).first()
                if report:
                    report.content = report_content
                    report.model_used = self.report_generator.model_name
                    db.commit()

            # Mark progress complete for pollers.
            if progress_dict is not None and report_id in progress_dict:
                progress_dict[report_id] = {
                    "progress": 1.0,
                    "status": "Report generation complete",
                    "current_chunk": 0,
                    "total_chunks": 0,
                    "current_report": None,
                }

        except Exception as e:
            # Surface the failure to pollers via the progress dict.
            if progress_dict is not None and report_id in progress_dict:
                progress_dict[report_id] = {
                    "progress": 1.0,
                    "status": f"Error generating report: {str(e)}",
                    "current_chunk": 0,
                    "total_chunks": 0,
                    "current_report": None,
                }

            # Also record the error as the report content so the row isn't
            # left half-written.
            if db:
                report = db.query(Report).filter(Report.id == report_id).first()
                if report:
                    report.content = f"Error generating report: {str(e)}"
                    db.commit()

            # Re-raise so the background-task runner can log the failure.
            raise
|
||||
|
||||
async def generate_report_file(self, report: "Report", format: str = "markdown") -> str:
    """
    Generate a report file in the specified format.

    Args:
        report: Report record; its ``content`` field holds the markdown source.
        format: Output format: "markdown", "html", or "pdf". Unknown formats
            fall back to writing the raw markdown content.

    Returns:
        Path (as a string) to the generated file inside ``self.temp_dir``.
    """
    # Target file named after the report, with the requested extension.
    file_path = self.temp_dir / f"report_{report.id}.{format}"

    if format == "html":
        # Render the markdown into a standalone, styled HTML page.
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(self._render_html(report))
    elif format == "pdf":
        try:
            # weasyprint is an optional dependency; import lazily so the
            # markdown fallback below still works when it is missing.
            from weasyprint import HTML

            # WeasyPrint renders from a file, so write a temporary HTML
            # document, convert it to PDF, then remove the intermediate.
            html_file_path = self.temp_dir / f"report_{report.id}.html"
            with open(html_file_path, "w", encoding="utf-8") as f:
                f.write(self._render_html(report))
            HTML(filename=str(html_file_path)).write_pdf(str(file_path))
            html_file_path.unlink()
        except ImportError:
            # PDF support unavailable: fall back to the raw markdown.
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(report.content)
    else:
        # "markdown", or any unsupported format: write the raw content.
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(report.content)

    return str(file_path)

def _render_html(self, report: "Report") -> str:
    """
    Convert the report's markdown content into a styled, standalone HTML page.

    Shared by the "html" and "pdf" branches of generate_report_file, which
    previously duplicated this template inline.
    """
    import markdown  # imported lazily, matching the original branches

    body = markdown.markdown(report.content)
    return f"""
    <!DOCTYPE html>
    <html>
    <head>
        <title>{report.title}</title>
        <meta charset="utf-8">
        <meta name="viewport" content="width=device-width, initial-scale=1">
        <style>
            body {{
                font-family: Arial, sans-serif;
                line-height: 1.6;
                max-width: 800px;
                margin: 0 auto;
                padding: 20px;
            }}
            h1, h2, h3, h4, h5, h6 {{
                margin-top: 1.5em;
                margin-bottom: 0.5em;
            }}
            a {{
                color: #0366d6;
                text-decoration: none;
            }}
            a:hover {{
                text-decoration: underline;
            }}
            pre {{
                background-color: #f6f8fa;
                border-radius: 3px;
                padding: 16px;
                overflow: auto;
            }}
            code {{
                background-color: #f6f8fa;
                border-radius: 3px;
                padding: 0.2em 0.4em;
                font-family: monospace;
            }}
            blockquote {{
                border-left: 4px solid #dfe2e5;
                padding-left: 16px;
                margin-left: 0;
                color: #6a737d;
            }}
            table {{
                border-collapse: collapse;
                width: 100%;
            }}
            table, th, td {{
                border: 1px solid #dfe2e5;
            }}
            th, td {{
                padding: 8px 16px;
                text-align: left;
            }}
            tr:nth-child(even) {{
                background-color: #f6f8fa;
            }}
        </style>
    </head>
    <body>
        {body}
    </body>
    </html>
    """
|
|
@ -0,0 +1,182 @@
|
|||
"""
|
||||
Search service for the sim-search API.
|
||||
|
||||
This module provides services for search execution and result management.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.db.models import Search
|
||||
|
||||
# Add sim-search to the python path
|
||||
sim_search_path = Path(settings.SIM_SEARCH_PATH)
|
||||
sys.path.append(str(sim_search_path))
|
||||
|
||||
# Import sim-search components
|
||||
from execution.search_executor import SearchExecutor
|
||||
from execution.result_collector import ResultCollector
|
||||
|
||||
|
||||
class SearchService:
    """
    Service for search execution and result management.

    Thin wrapper around the sim-search SearchExecutor / ResultCollector
    pair: runs searches, optionally persists them, and reads back stored
    results in a uniform response shape.
    """

    def __init__(self):
        """Instantiate the underlying sim-search components."""
        self.search_executor = SearchExecutor()
        self.result_collector = ResultCollector()

    async def get_available_search_engines(self) -> List[str]:
        """Return the names of all search engines the executor can use."""
        return self.search_executor.get_available_search_engines()

    async def execute_search(
        self,
        structured_query: Dict[str, Any],
        search_engines: Optional[List[str]] = None,
        num_results: Optional[int] = 10,
        timeout: Optional[int] = 30,
        user_id: Optional[str] = None,
        db: Optional[Session] = None,
    ) -> Dict[str, Any]:
        """
        Execute a search with the given parameters.

        Args:
            structured_query: Structured query (normalized in place below)
            search_engines: Search engines to use; defaults to all available
            num_results: Results to return per engine
            timeout: Timeout in seconds
            user_id: If given together with db, the search is persisted
            db: Database session used for persistence

        Returns:
            Dict with search_id, query, enhanced_query, per-engine results,
            total_results and execution_time.
        """
        started = time.time()

        # Guard against a missing query dict.
        if structured_query is None:
            structured_query = {}

        # Default to every available engine and record the choice.
        if not search_engines:
            search_engines = self.search_executor.get_available_search_engines()
        structured_query["search_engines"] = search_engines

        # Normalize the fields the search executor requires.
        original_query = structured_query.get("original_query", "")
        structured_query["raw_query"] = structured_query.get("raw_query", original_query)
        structured_query.setdefault("enhanced_query", original_query)
        # search_queries must be a dict, never None.
        if structured_query.get("search_queries") is None:
            structured_query["search_queries"] = {}

        # Run the actual search.
        search_results = self.search_executor.execute_search(
            structured_query=structured_query,
            search_engines=search_engines,
            num_results=num_results,
            timeout=timeout,
        )
        execution_time = time.time() - started

        # Deduplicate / rerank the raw per-engine results.
        processed_results = self.result_collector.process_results(
            search_results, dedup=True, max_results=None, use_reranker=True
        )

        # Persist the search when both a user and a session are supplied.
        search_id = None
        if user_id and db:
            record = Search(
                user_id=user_id,
                query=structured_query.get("original_query", ""),
                enhanced_query=structured_query.get("enhanced_query", ""),
                query_type=structured_query.get("type", ""),
                engines=",".join(search_engines) if search_engines else "",
                results_count=len(processed_results),
                results=processed_results,
            )
            db.add(record)
            db.commit()
            db.refresh(record)
            search_id = record.id

        return {
            "search_id": search_id,
            "query": structured_query.get("original_query", ""),
            "enhanced_query": structured_query.get("enhanced_query", ""),
            "results": dict(search_results),
            "total_results": sum(len(r) for r in search_results.values()),
            "execution_time": execution_time,
        }

    async def get_search_results(self, search: Search) -> Dict[str, Any]:
        """
        Get results for a specific stored search.

        Args:
            search: Search record

        Returns:
            Search results in the same shape execute_search returns.
        """
        # Stored results may already be keyed by engine, or may be a flat
        # list that needs grouping by each item's "source".
        if isinstance(search.results, dict):
            grouped = search.results
        else:
            grouped = {}
            for item in search.results:
                grouped.setdefault(item.get("source", "unknown"), []).append(item)

        return {
            "search_id": search.id,
            "query": search.query,
            "enhanced_query": search.enhanced_query,
            "results": grouped,
            "total_results": search.results_count,
            "execution_time": 0.0,  # not recorded for stored searches
        }
|
|
@ -0,0 +1,30 @@
|
|||
# FastAPI and ASGI server
|
||||
fastapi==0.103.1
|
||||
uvicorn==0.23.2
|
||||
|
||||
# Database
|
||||
sqlalchemy==2.0.21
|
||||
alembic==1.12.0
|
||||
|
||||
# Authentication
|
||||
python-jose==3.3.0
|
||||
passlib==1.7.4
|
||||
bcrypt==4.0.1
|
||||
python-multipart==0.0.6
|
||||
|
||||
# Validation and serialization
|
||||
pydantic==2.4.2
|
||||
email-validator==2.0.0
|
||||
|
||||
# Testing
|
||||
pytest==7.4.2
|
||||
httpx==0.25.0
|
||||
|
||||
# Utilities
|
||||
python-dotenv==1.0.0
|
||||
aiofiles==23.2.1
|
||||
jinja2==3.1.2
|
||||
|
||||
# Report generation
|
||||
markdown==3.4.4
|
||||
weasyprint==60.1 # Optional, for PDF generation
|
|
@ -0,0 +1,56 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Run script for the sim-search API.
|
||||
|
||||
This script launches the FastAPI application using uvicorn.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import uvicorn
|
||||
|
||||
|
||||
def parse_args():
    """Build and evaluate the CLI parser for the API launcher."""
    parser = argparse.ArgumentParser(description="Run the sim-search API")
    # Network binding options.
    parser.add_argument("--host", type=str, default="127.0.0.1",
                        help="Host to run the server on")
    parser.add_argument("--port", type=int, default=8000,
                        help="Port to run the server on")
    # Development conveniences.
    parser.add_argument("--reload", action="store_true",
                        help="Enable auto-reload for development")
    parser.add_argument("--debug", action="store_true",
                        help="Run in debug mode")
    return parser.parse_args()
|
||||
|
||||
|
||||
def main():
    """Entry point: parse CLI options and launch the ASGI server."""
    args = parse_args()

    print(f"Starting sim-search API on {args.host}:{args.port}...")

    # Map the --debug flag onto uvicorn's log level.
    level = "debug" if args.debug else "info"
    uvicorn.run(
        "app.main:app",
        host=args.host,
        port=args.port,
        reload=args.reload,
        log_level=level,
    )


if __name__ == "__main__":
    main()
|
|
@ -0,0 +1,83 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Run script for the sim-search API tests.
|
||||
|
||||
This script runs the API tests and provides a clear output of the test results.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
def parse_args():
    """Build and evaluate the CLI parser for the test runner."""
    parser = argparse.ArgumentParser(description="Run the sim-search API tests")
    parser.add_argument("--test-file", type=str, default="tests/test_api.py",
                        help="Test file to run")
    parser.add_argument("--verbose", action="store_true",
                        help="Enable verbose output")
    parser.add_argument("--xvs", action="store_true",
                        help="Run tests with -xvs flag (exit on first failure, verbose, show output)")
    parser.add_argument("--coverage", action="store_true",
                        help="Run tests with coverage")
    return parser.parse_args()
|
||||
|
||||
def run_tests(args):
    """Run pytest with the options encoded in *args*; return its exit code."""
    print(f"Running tests from {args.test_file}...")

    # Assemble the pytest invocation from the CLI flags.
    command = ["pytest"]
    if args.xvs:
        command.append("-xvs")
    elif args.verbose:
        command.append("-v")
    if args.coverage:
        command.extend(["--cov=app", "--cov-report=term", "--cov-report=html"])
    command.append(args.test_file)

    # Run and time the test suite.
    started = time.time()
    result = subprocess.run(command)
    elapsed = time.time() - started

    if result.returncode == 0:
        print(f"\n✅ Tests passed in {elapsed:.2f} seconds")
    else:
        print(f"\n❌ Tests failed in {elapsed:.2f} seconds")

    return result.returncode
|
||||
|
||||
def main():
    """Validate the target test file, then run the suite."""
    args = parse_args()

    # Fail fast if the requested test file is missing.
    if not os.path.exists(args.test_file):
        print(f"Error: Test file {args.test_file} does not exist")
        return 1

    return run_tests(args)


if __name__ == "__main__":
    sys.exit(main())
|
|
@ -0,0 +1,382 @@
|
|||
#!/bin/bash
# Exercise the sim-search API end to end with curl.

# --- Configuration ---
API_URL="http://localhost:8000"
API_V1="${API_URL}/api/v1"
TOKEN=""
EMAIL="test@example.com"
PASSWORD="password123"
FULL_NAME="Test User"

# ANSI colors for test output.
GREEN='\033[0;32m'
RED='\033[0;31m'
YELLOW='\033[0;33m'
NC='\033[0m' # No Color

# Print a yellow section banner.
print_header() {
    echo -e "\n${YELLOW}=== $1 ===${NC}"
}

# Print a green success line.
print_success() {
    echo -e "${GREEN}✓ $1${NC}"
}

# Print a red failure line.
print_error() {
    echo -e "${RED}✗ $1${NC}"
}

# Abort unless the API root answers with HTTP 200.
check_api() {
    print_header "Checking if API is running"

    local status
    status=$(curl -s -o /dev/null -w "%{http_code}" ${API_URL})

    if [ "$status" != "200" ]; then
        print_error "API is not running. Please start the API server first."
        exit 1
    fi
    print_success "API is running"
}
|
||||
|
||||
# Create the test account; tolerate an already-existing user.
register_user() {
    print_header "Registering a user"

    local resp
    resp=$(curl -s -X POST \
        -H "Content-Type: application/json" \
        -d "{\"email\":\"${EMAIL}\",\"password\":\"${PASSWORD}\",\"full_name\":\"${FULL_NAME}\",\"is_active\":true,\"is_superuser\":false}" \
        ${API_V1}/auth/register)

    if echo "$resp" | grep -q "email"; then
        print_success "User registered successfully"
    elif echo "$resp" | grep -q "already exists"; then
        # A pre-existing user is fine; we only need to be able to log in.
        print_success "User already exists, continuing with login"
    else
        print_error "Failed to register user: $resp"
    fi
}

# Log in and store the bearer token in TOKEN (aborts on failure).
get_token() {
    print_header "Getting authentication token"

    local resp
    resp=$(curl -s -X POST \
        -H "Content-Type: application/x-www-form-urlencoded" \
        -d "username=${EMAIL}&password=${PASSWORD}" \
        ${API_V1}/auth/token)

    if ! echo "$resp" | grep -q "access_token"; then
        print_error "Failed to get authentication token: $resp"
        exit 1
    fi
    TOKEN=$(echo "$resp" | grep -o '"access_token":"[^"]*' | sed 's/"access_token":"//')
    print_success "Got authentication token"
}
|
||||
|
||||
# POST a natural-language query for processing.
process_query() {
    print_header "Processing a query"

    local resp
    resp=$(curl -s -X POST \
        -H "Content-Type: application/json" \
        -H "Authorization: Bearer ${TOKEN}" \
        -d "{\"query\":\"What are the environmental impacts of electric vehicles?\"}" \
        ${API_V1}/query/process)

    if echo "$resp" | grep -q "structured_query"; then
        print_success "Query processed successfully"
    else
        print_error "Failed to process query: $resp"
    fi
}

# POST a natural-language query for classification.
classify_query() {
    print_header "Classifying a query"

    local resp
    resp=$(curl -s -X POST \
        -H "Content-Type: application/json" \
        -H "Authorization: Bearer ${TOKEN}" \
        -d "{\"query\":\"What are the environmental impacts of electric vehicles?\"}" \
        ${API_V1}/query/classify)

    if echo "$resp" | grep -q "structured_query"; then
        print_success "Query classified successfully"
    else
        print_error "Failed to classify query: $resp"
    fi
}

# List the search engines the backend can use.
get_search_engines() {
    print_header "Getting available search engines"

    local resp
    resp=$(curl -s -X GET \
        -H "Authorization: Bearer ${TOKEN}" \
        ${API_V1}/search/engines)

    if echo "$resp" | grep -q "\["; then
        print_success "Got search engines successfully"
    else
        print_error "Failed to get search engines: $resp"
    fi
}

# Run a search and remember its id in SEARCH_ID.
execute_search() {
    print_header "Executing a search"

    local resp
    resp=$(curl -s -X POST \
        -H "Content-Type: application/json" \
        -H "Authorization: Bearer ${TOKEN}" \
        -d "{\"structured_query\":{\"original_query\":\"What are the environmental impacts of electric vehicles?\",\"enhanced_query\":\"What are the environmental impacts of electric vehicles?\",\"type\":\"factual\",\"domain\":\"environmental\"},\"search_engines\":[\"google\",\"arxiv\"],\"num_results\":5,\"timeout\":30}" \
        ${API_V1}/search/execute)

    if echo "$resp" | grep -q "search_id"; then
        SEARCH_ID=$(echo "$resp" | grep -o '"search_id":"[^"]*' | sed 's/"search_id":"//')
        print_success "Search executed successfully with ID: $SEARCH_ID"
    else
        print_error "Failed to execute search: $resp"
    fi
}

# Fetch the current user's search history.
get_search_history() {
    print_header "Getting search history"

    local resp
    resp=$(curl -s -X GET \
        -H "Authorization: Bearer ${TOKEN}" \
        ${API_V1}/search/history)

    if echo "$resp" | grep -q "searches"; then
        print_success "Got search history successfully"
    else
        print_error "Failed to get search history: $resp"
    fi
}

# Fetch the results of the search stored in SEARCH_ID.
get_search_results() {
    print_header "Getting search results"

    if [ -z "$SEARCH_ID" ]; then
        print_error "No search ID available. Please execute a search first."
        return
    fi

    local resp
    resp=$(curl -s -X GET \
        -H "Authorization: Bearer ${TOKEN}" \
        ${API_V1}/search/${SEARCH_ID})

    if echo "$resp" | grep -q "search_id"; then
        print_success "Got search results successfully"
    else
        print_error "Failed to get search results: $resp"
    fi
}
|
||||
|
||||
# Generate a report from SEARCH_ID and remember its id in REPORT_ID.
generate_report() {
    print_header "Generating a report"

    if [ -z "$SEARCH_ID" ]; then
        print_error "No search ID available. Please execute a search first."
        return
    fi

    local resp
    resp=$(curl -s -X POST \
        -H "Content-Type: application/json" \
        -H "Authorization: Bearer ${TOKEN}" \
        -d "{\"search_id\":\"${SEARCH_ID}\",\"query\":\"What are the environmental impacts of electric vehicles?\",\"detail_level\":\"standard\",\"query_type\":\"factual\",\"model\":\"llama-3.1-8b-instant\"}" \
        ${API_V1}/report/generate)

    if echo "$resp" | grep -q "id"; then
        REPORT_ID=$(echo "$resp" | grep -o '"id":"[^"]*' | sed 's/"id":"//')
        print_success "Report generated successfully with ID: $REPORT_ID"
    else
        print_error "Failed to generate report: $resp"
    fi
}

# Poll generation progress for REPORT_ID.
get_report_progress() {
    print_header "Getting report progress"

    if [ -z "$REPORT_ID" ]; then
        print_error "No report ID available. Please generate a report first."
        return
    fi

    local resp
    resp=$(curl -s -X GET \
        -H "Authorization: Bearer ${TOKEN}" \
        ${API_V1}/report/${REPORT_ID}/progress)

    if echo "$resp" | grep -q "progress"; then
        print_success "Got report progress successfully"
    else
        print_error "Failed to get report progress: $resp"
    fi
}

# Fetch the list of the user's reports.
get_report_list() {
    print_header "Getting report list"

    local resp
    resp=$(curl -s -X GET \
        -H "Authorization: Bearer ${TOKEN}" \
        ${API_V1}/report/list)

    if echo "$resp" | grep -q "reports"; then
        print_success "Got report list successfully"
    else
        print_error "Failed to get report list: $resp"
    fi
}

# Fetch the single report stored in REPORT_ID.
get_report() {
    print_header "Getting a specific report"

    if [ -z "$REPORT_ID" ]; then
        print_error "No report ID available. Please generate a report first."
        return
    fi

    local resp
    resp=$(curl -s -X GET \
        -H "Authorization: Bearer ${TOKEN}" \
        ${API_V1}/report/${REPORT_ID})

    if echo "$resp" | grep -q "id"; then
        print_success "Got report successfully"
    else
        print_error "Failed to get report: $resp"
    fi
}

# Download REPORT_ID as markdown.
download_report() {
    print_header "Downloading a report"

    if [ -z "$REPORT_ID" ]; then
        print_error "No report ID available. Please generate a report first."
        return
    fi

    local resp
    resp=$(curl -s -X GET \
        -H "Authorization: Bearer ${TOKEN}" \
        ${API_V1}/report/${REPORT_ID}/download?format=markdown)

    if [ -n "$resp" ]; then
        print_success "Downloaded report successfully"
    else
        print_error "Failed to download report"
    fi
}

# Delete REPORT_ID; expects HTTP 204.
delete_report() {
    print_header "Deleting a report"

    if [ -z "$REPORT_ID" ]; then
        print_error "No report ID available. Please generate a report first."
        return
    fi

    local code
    code=$(curl -s -o /dev/null -w "%{http_code}" -X DELETE \
        -H "Authorization: Bearer ${TOKEN}" \
        ${API_V1}/report/${REPORT_ID})

    if [ "$code" == "204" ]; then
        print_success "Report deleted successfully"
    else
        print_error "Failed to delete report: $code"
    fi
}

# Delete SEARCH_ID; expects HTTP 204.
delete_search() {
    print_header "Deleting a search"

    if [ -z "$SEARCH_ID" ]; then
        print_error "No search ID available. Please execute a search first."
        return
    fi

    local code
    code=$(curl -s -o /dev/null -w "%{http_code}" -X DELETE \
        -H "Authorization: Bearer ${TOKEN}" \
        ${API_V1}/search/${SEARCH_ID})

    if [ "$code" == "204" ]; then
        print_success "Search deleted successfully"
    else
        print_error "Failed to delete search: $code"
    fi
}
|
||||
|
||||
# Drive the full test sequence against a running API instance.
main() {
    echo "Starting API tests..."

    check_api

    # Authentication
    register_user
    get_token

    # Query processing
    process_query
    classify_query

    # Search
    get_search_engines
    execute_search
    get_search_history
    get_search_results

    # Reports
    generate_report
    get_report_progress
    get_report_list
    get_report
    download_report

    # Cleanup
    delete_report
    delete_search

    echo -e "\n${GREEN}All tests completed!${NC}"
}

# Run the main function
main
|
|
@ -0,0 +1,101 @@
|
|||
# Sim-Search API Tests
|
||||
|
||||
This directory contains tests for the Sim-Search API.
|
||||
|
||||
## Test Files
|
||||
|
||||
- `test_api.py`: Tests the core functionality of the API, including authentication, query processing, search execution, and report generation.
|
||||
|
||||
## Running Tests
|
||||
|
||||
### Using pytest directly
|
||||
|
||||
```bash
|
||||
# Run all tests
|
||||
pytest
|
||||
|
||||
# Run a specific test file
|
||||
pytest tests/test_api.py
|
||||
|
||||
# Run tests with verbose output
|
||||
pytest -v tests/test_api.py
|
||||
|
||||
# Run tests with verbose output and exit on first failure
|
||||
pytest -xvs tests/test_api.py
|
||||
|
||||
# Run tests with coverage report
|
||||
pytest --cov=app --cov-report=term --cov-report=html tests/test_api.py
|
||||
```
|
||||
|
||||
### Using the run_tests.py script
|
||||
|
||||
We provide a convenient script to run the tests with various options:
|
||||
|
||||
```bash
|
||||
# Run all tests
|
||||
python run_tests.py
|
||||
|
||||
# Run with verbose output
|
||||
python run_tests.py --verbose
|
||||
|
||||
# Run with -xvs flag (exit on first failure, verbose, show output)
|
||||
python run_tests.py --xvs
|
||||
|
||||
# Run with coverage report
|
||||
python run_tests.py --coverage
|
||||
|
||||
# Run a specific test file
|
||||
python run_tests.py --test-file tests/test_api.py
|
||||
```
|
||||
|
||||
### Using the test_api_curl.sh script
|
||||
|
||||
For manual testing of the API endpoints using curl commands:
|
||||
|
||||
```bash
|
||||
# Make the script executable
|
||||
chmod +x test_api_curl.sh
|
||||
|
||||
# Run the script
|
||||
./test_api_curl.sh
|
||||
```
|
||||
|
||||
This script will test all the API endpoints in sequence, including:
|
||||
- Authentication (register, login)
|
||||
- Query processing and classification
|
||||
- Search execution and retrieval
|
||||
- Report generation and management
|
||||
|
||||
## Test Database
|
||||
|
||||
The tests use a separate SQLite database (`test.db`) to avoid affecting the production database. This database is created and destroyed during the test run.
|
||||
|
||||
## Test User
|
||||
|
||||
The tests create a test user with the following credentials:
|
||||
- Email: test@example.com
|
||||
- Password: password123
|
||||
- Full Name: Test User
|
||||
|
||||
## Test Coverage
|
||||
|
||||
To generate a test coverage report:
|
||||
|
||||
```bash
|
||||
pytest --cov=app --cov-report=term --cov-report=html tests/
|
||||
```
|
||||
|
||||
This will generate a coverage report in the terminal and an HTML report in the `htmlcov` directory.
|
||||
|
||||
## Continuous Integration
|
||||
|
||||
These tests can be integrated into a CI/CD pipeline to ensure that the API is working correctly before deployment.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
If you encounter issues with the tests:
|
||||
|
||||
1. Make sure the API server is not running when running the tests, as they will start their own instance.
|
||||
2. Check that the test database is not locked by another process.
|
||||
3. Ensure that all dependencies are installed (`pip install -r requirements.txt`).
|
||||
4. If you're getting authentication errors, make sure the JWT secret key is set correctly in the test environment.
|
|
@ -0,0 +1,480 @@
|
|||
"""
|
||||
Test script for the sim-search API.
|
||||
|
||||
This script tests the core functionality of the API, including authentication,
|
||||
query processing, search execution, and report generation.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
# Add the project root directory to the Python path
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
from app.main import app
|
||||
from app.db.session import Base
|
||||
from app.db.models import User
|
||||
from app.core.security import get_password_hash
|
||||
from app.core.config import settings
|
||||
from app.api.dependencies import get_db
|
||||
|
||||
# Create a test database
|
||||
TEST_SQLALCHEMY_DATABASE_URI = "sqlite:///./test.db"
|
||||
engine = create_engine(
|
||||
TEST_SQLALCHEMY_DATABASE_URI,
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
# Override the get_db dependency
|
||||
def override_get_db():
    """Dependency override that yields a session bound to the test database."""
    session = TestingSessionLocal()
    try:
        yield session
    finally:
        session.close()


# Route all request-scoped DB access through the test database.
app.dependency_overrides[get_db] = override_get_db

# In-process client used by every test to call the app.
client = TestClient(app)

# Credentials for the fixture-created test user.
test_user_email = "test@example.com"
test_user_password = "password123"
test_user_full_name = "Test User"
|
||||
|
||||
@pytest.fixture(scope="module")
def setup_database():
    """Create a fresh test database with one seeded user; drop it afterwards."""
    # Start from a clean slate.
    if os.path.exists("./test.db"):
        os.remove("./test.db")

    Base.metadata.create_all(bind=engine)

    # Seed the user the auth tests log in as.
    session = TestingSessionLocal()
    seeded = User(
        email=test_user_email,
        hashed_password=get_password_hash(test_user_password),
        full_name=test_user_full_name,
        is_active=True,
        is_superuser=False,
    )
    session.add(seeded)
    try:
        session.commit()
        session.refresh(seeded)
    except Exception as e:
        # Creation failure is reported but does not abort collection.
        session.rollback()
        print(f"Error creating test user: {e}")
    finally:
        session.close()

    yield

    # Tear down: drop all tables and remove the database file.
    Base.metadata.drop_all(bind=engine)
    if os.path.exists("./test.db"):
        os.remove("./test.db")
|
||||
|
||||
@pytest.fixture(scope="module")
def auth_token(setup_database):
    """Log in as the seeded test user and return a bearer access token."""
    resp = client.post(
        f"{settings.API_V1_STR}/auth/token",
        data={"username": test_user_email, "password": test_user_password},
    )
    assert resp.status_code == 200
    payload = resp.json()
    assert "access_token" in payload
    assert payload["token_type"] == "bearer"
    return payload["access_token"]
|
||||
|
||||
def test_root():
    """The root endpoint reports service status and metadata."""
    resp = client.get("/")
    assert resp.status_code == 200
    body = resp.json()
    expected = {
        "status": "online",
        "version": settings.VERSION,
        "project": settings.PROJECT_NAME,
        "docs": "/docs",
    }
    for key, value in expected.items():
        assert body[key] == value
|
||||
|
||||
def test_auth_token(setup_database):
    """Valid credentials yield a bearer access token."""
    resp = client.post(
        f"{settings.API_V1_STR}/auth/token",
        data={"username": test_user_email, "password": test_user_password},
    )
    assert resp.status_code == 200
    payload = resp.json()
    assert "access_token" in payload
    assert payload["token_type"] == "bearer"
|
||||
|
||||
def test_auth_token_invalid_credentials(setup_database):
    """A wrong password is rejected with 401 and a clear error message."""
    resp = client.post(
        f"{settings.API_V1_STR}/auth/token",
        data={"username": test_user_email, "password": "wrong_password"},
    )
    assert resp.status_code == 401
    assert resp.json()["detail"] == "Incorrect email or password"
|
||||
|
||||
def test_register_user(setup_database):
    """Registering a fresh email returns the new user's public fields.

    Fix: boolean fields are checked with ``is True`` / ``is False`` instead
    of ``==`` (PEP 8 / flake8 E712), which also catches truthy non-bool
    values leaking through the schema.
    """
    response = client.post(
        f"{settings.API_V1_STR}/auth/register",
        json={
            "email": "new_user@example.com",
            "password": "password123",
            "full_name": "New User",
            "is_active": True,
            "is_superuser": False,
        },
    )
    assert response.status_code == 200
    user_data = response.json()
    assert user_data["email"] == "new_user@example.com"
    assert user_data["full_name"] == "New User"
    assert user_data["is_active"] is True
    assert user_data["is_superuser"] is False
|
||||
|
||||
def test_register_existing_user(setup_database):
    """Registering with an already-taken email must fail with 400."""
    resp = client.post(
        f"{settings.API_V1_STR}/auth/register",
        json={
            "email": test_user_email,
            "password": "password123",
            "full_name": "Duplicate User",
            "is_active": True,
            "is_superuser": False,
        },
    )
    assert resp.status_code == 400
    assert resp.json()["detail"] == "A user with this email already exists"
|
||||
|
||||
def test_process_query(auth_token):
    """Processing a query echoes it back inside a structured_query object."""
    query_text = "What are the environmental impacts of electric vehicles?"
    resp = client.post(
        f"{settings.API_V1_STR}/query/process",
        json={"query": query_text},
        headers={"Authorization": f"Bearer {auth_token}"},
    )
    assert resp.status_code == 200
    body = resp.json()
    assert body["original_query"] == query_text
    assert "structured_query" in body
    assert body["structured_query"]["original_query"] == query_text
||||
|
||||
def test_classify_query(auth_token):
    """Classification adds type and domain fields to the structured query."""
    query_text = "What are the environmental impacts of electric vehicles?"
    resp = client.post(
        f"{settings.API_V1_STR}/query/classify",
        json={"query": query_text},
        headers={"Authorization": f"Bearer {auth_token}"},
    )
    assert resp.status_code == 200
    body = resp.json()
    assert body["original_query"] == query_text
    assert "structured_query" in body
    structured = body["structured_query"]
    assert structured["original_query"] == query_text
    assert "type" in structured
    assert "domain" in structured
||||
|
||||
def test_get_available_search_engines(auth_token):
    """The engines endpoint returns a non-empty list."""
    resp = client.get(
        f"{settings.API_V1_STR}/search/engines",
        headers={"Authorization": f"Bearer {auth_token}"},
    )
    assert resp.status_code == 200
    engines = resp.json()
    assert isinstance(engines, list)
    assert engines  # at least one engine must be configured
|
||||
|
||||
def test_execute_search(auth_token):
    """Executing a search returns an id, results, and timing metadata."""
    query_text = "What are the environmental impacts of electric vehicles?"
    payload = {
        "structured_query": {
            "original_query": query_text,
            "enhanced_query": query_text,
            "type": "factual",
            "domain": "environmental",
        },
        "search_engines": ["google", "arxiv"],
        "num_results": 5,
        "timeout": 30,
    }
    resp = client.post(
        f"{settings.API_V1_STR}/search/execute",
        json=payload,
        headers={"Authorization": f"Bearer {auth_token}"},
    )
    assert resp.status_code == 200
    body = resp.json()
    assert "search_id" in body
    assert body["query"] == query_text
    for field in ("results", "total_results", "execution_time"):
        assert field in body
||||
|
||||
def test_get_search_history(auth_token):
    """Search history lists prior searches with a total count."""
    resp = client.get(
        f"{settings.API_V1_STR}/search/history",
        headers={"Authorization": f"Bearer {auth_token}"},
    )
    assert resp.status_code == 200
    history = resp.json()
    assert isinstance(history.get("searches"), list)
    assert isinstance(history.get("total"), int)
|
||||
|
||||
def test_get_search_results(auth_token):
    """A stored search can be fetched again by its search_id."""
    auth_headers = {"Authorization": f"Bearer {auth_token}"}
    query_text = "What are the economic benefits of electric vehicles?"

    # First, execute a search to obtain a search_id.
    resp = client.post(
        f"{settings.API_V1_STR}/search/execute",
        json={
            "structured_query": {
                "original_query": query_text,
                "enhanced_query": query_text,
                "type": "factual",
                "domain": "economic",
            },
            "search_engines": ["google", "arxiv"],
            "num_results": 5,
            "timeout": 30,
        },
        headers=auth_headers,
    )
    assert resp.status_code == 200
    search_id = resp.json()["search_id"]

    # Then fetch the stored results by id.
    resp = client.get(
        f"{settings.API_V1_STR}/search/{search_id}",
        headers=auth_headers,
    )
    assert resp.status_code == 200
    body = resp.json()
    assert body["search_id"] == search_id
    assert body["query"] == query_text
    assert "results" in body
    assert "total_results" in body
||||
|
||||
def test_generate_report(auth_token):
    """Generating a report from a search returns metadata and progress."""
    auth_headers = {"Authorization": f"Bearer {auth_token}"}
    query_text = "What are the environmental and economic impacts of electric vehicles?"

    # First, execute a search to obtain a search_id.
    resp = client.post(
        f"{settings.API_V1_STR}/search/execute",
        json={
            "structured_query": {
                "original_query": query_text,
                "enhanced_query": query_text,
                "type": "comparative",
                "domain": "environmental,economic",
            },
            "search_engines": ["google", "arxiv"],
            "num_results": 5,
            "timeout": 30,
        },
        headers=auth_headers,
    )
    assert resp.status_code == 200
    search_id = resp.json()["search_id"]

    # Generate a report from that search.
    resp = client.post(
        f"{settings.API_V1_STR}/report/generate",
        json={
            "search_id": search_id,
            "query": query_text,
            "detail_level": "standard",
            "query_type": "comparative",
            "model": "llama-3.1-8b-instant",
        },
        headers=auth_headers,
    )
    assert resp.status_code == 200
    report = resp.json()
    assert "id" in report
    assert report["title"].startswith("Report: What are the environmental and economic impacts")
    assert report["detail_level"] == "standard"
    assert report["query_type"] == "comparative"
    assert report["model_used"] == "llama-3.1-8b-instant"

    # Poll the report's progress endpoint once.
    report_id = report["id"]
    resp = client.get(
        f"{settings.API_V1_STR}/report/{report_id}/progress",
        headers=auth_headers,
    )
    assert resp.status_code == 200
    progress = resp.json()
    assert progress["report_id"] == report_id
    assert "progress" in progress
    assert "status" in progress
||||
|
||||
def test_get_report_list(auth_token):
    """The report list endpoint returns reports plus a total count."""
    resp = client.get(
        f"{settings.API_V1_STR}/report/list",
        headers={"Authorization": f"Bearer {auth_token}"},
    )
    assert resp.status_code == 200
    body = resp.json()
    assert isinstance(body.get("reports"), list)
    assert isinstance(body.get("total"), int)
|
||||
|
||||
def test_get_report(auth_token):
    """A report from the list can be fetched individually by id."""
    auth_headers = {"Authorization": f"Bearer {auth_token}"}

    # Grab an existing report id from the list endpoint.
    resp = client.get(f"{settings.API_V1_STR}/report/list", headers=auth_headers)
    assert resp.status_code == 200
    reports = resp.json()["reports"]
    assert len(reports) > 0
    report_id = reports[0]["id"]

    # Fetch that specific report.
    resp = client.get(
        f"{settings.API_V1_STR}/report/{report_id}",
        headers=auth_headers,
    )
    assert resp.status_code == 200
    report = resp.json()
    assert report["id"] == report_id
    for field in ("title", "content", "detail_level", "query_type", "model_used"):
        assert field in report
||||
|
||||
def test_download_report(auth_token):
    """A generated report can be downloaded in markdown and HTML formats."""
    auth_headers = {"Authorization": f"Bearer {auth_token}"}
    query_text = "What are the environmental impacts of electric vehicles?"

    # First, execute a search to obtain a search_id.
    resp = client.post(
        f"{settings.API_V1_STR}/search/execute",
        json={
            "structured_query": {
                "original_query": query_text,
                "enhanced_query": query_text,
                "type": "comparative",
                "domain": "environmental,economic",
            },
            "search_engines": ["serper"],
            "num_results": 2,
            "timeout": 10,
        },
        headers=auth_headers,
    )
    assert resp.status_code == 200
    search_id = resp.json()["search_id"]

    # Generate a brief report from that search.
    resp = client.post(
        f"{settings.API_V1_STR}/report/generate",
        json={
            "search_id": search_id,
            "query": query_text,
            "detail_level": "brief",
            "query_type": "comparative",
            "model": "llama-3.1-8b-instant",
        },
        headers=auth_headers,
    )
    assert resp.status_code == 200
    report_id = resp.json()["id"]

    # Download the report in each supported format and check the headers.
    for fmt in ("markdown", "html"):
        resp = client.get(
            f"{settings.API_V1_STR}/report/{report_id}/download?format={fmt}",
            headers=auth_headers,
        )
        assert resp.status_code == 200
        assert resp.headers["content-type"] == "application/octet-stream"
        assert resp.headers["content-disposition"] == f'attachment; filename="report_{report_id}.{fmt}"'
||||
|
||||
def test_delete_report(auth_token):
    """Deleting a report returns 204 and subsequent fetches 404."""
    auth_headers = {"Authorization": f"Bearer {auth_token}"}

    # Grab an existing report id from the list endpoint.
    resp = client.get(f"{settings.API_V1_STR}/report/list", headers=auth_headers)
    assert resp.status_code == 200
    reports = resp.json()["reports"]
    assert len(reports) > 0
    report_id = reports[0]["id"]

    # Delete it.
    resp = client.delete(
        f"{settings.API_V1_STR}/report/{report_id}",
        headers=auth_headers,
    )
    assert resp.status_code == 204

    # It must no longer be retrievable.
    resp = client.get(
        f"{settings.API_V1_STR}/report/{report_id}",
        headers=auth_headers,
    )
    assert resp.status_code == 404
||||
|
||||
def test_delete_search(auth_token):
    """Deleting a search returns 204 and subsequent fetches 404."""
    auth_headers = {"Authorization": f"Bearer {auth_token}"}

    # Grab an existing search id from the history endpoint.
    resp = client.get(f"{settings.API_V1_STR}/search/history", headers=auth_headers)
    assert resp.status_code == 200
    searches = resp.json()["searches"]
    assert len(searches) > 0
    search_id = searches[0]["id"]

    # Delete it.
    resp = client.delete(
        f"{settings.API_V1_STR}/search/{search_id}",
        headers=auth_headers,
    )
    assert resp.status_code == 204

    # It must no longer be retrievable.
    resp = client.get(
        f"{settings.API_V1_STR}/search/{search_id}",
        headers=auth_headers,
    )
    assert resp.status_code == 404
|
||||
|
||||
# Allow running this module directly: -x stop at first failure, -v verbose,
# -s do not capture stdout.
if __name__ == "__main__":
    pytest.main(["-xvs", __file__])
|
|
@ -0,0 +1,100 @@
|
|||
"""
|
||||
Integration test for query classification and search execution.
|
||||
|
||||
This test demonstrates how the LLM-based query domain classification
|
||||
affects the search engines selected for different types of queries.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Dict, Any, List
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
|
||||
|
||||
from query.query_processor import get_query_processor
|
||||
from execution.search_executor import get_search_executor
|
||||
|
||||
|
||||
async def test_query_classification_search_integration():
    """Test how query classification affects search engine selection.

    Runs several queries of different domains through the query processor,
    executes a search for each, and records which engines were selected.

    Fixes over the original:
    - The dict comprehension used ``results`` as its loop variable, shadowing
      the outer ``results`` list; renamed to ``engine_results``.
    - ``json.dump`` is called positionally (obj, fp) per its signature.
    - The final ``print`` used an f-string with no placeholders.
    """
    query_processor = get_query_processor()
    search_executor = get_search_executor()

    # Test queries for different domains
    test_queries = [
        {
            "description": "Academic query about quantum computing",
            "query": "What are the latest theoretical advances in quantum computing algorithms?"
        },
        {
            "description": "Code query about implementing a neural network",
            "query": "How do I implement a convolutional neural network in TensorFlow?"
        },
        {
            "description": "Current events query about economic policy",
            "query": "What are the recent changes to Federal Reserve interest rates and their economic impact?"
        },
        {
            "description": "Mixed query with academic and code aspects",
            "query": "How are transformer models being implemented for natural language processing tasks?"
        }
    ]

    results = []

    for test_case in test_queries:
        query = test_case["query"]
        description = test_case["description"]

        print(f"\n=== Testing: {description} ===")
        print(f"Query: {query}")

        # Process the query through the LLM-based classifier.
        structured_query = await query_processor.process_query(query)

        # Pull out the classification fields (defaults mirror the processor's).
        domain = structured_query.get('domain', 'general')
        domain_confidence = structured_query.get('domain_confidence', 0.0)
        is_academic = structured_query.get('is_academic', False)
        is_code = structured_query.get('is_code', False)
        is_current_events = structured_query.get('is_current_events', False)

        print(f"Domain: {domain} (confidence: {domain_confidence})")
        print(f"Is academic: {is_academic}")
        print(f"Is code: {is_code}")
        print(f"Is current events: {is_current_events}")

        # Execute search with default search engines based on classification.
        search_results = await search_executor.execute_search(structured_query)

        # The executor keys its result mapping by engine name.
        selected_engines = list(search_results.keys())
        print(f"Selected search engines: {selected_engines}")

        results.append({
            "query": query,
            "description": description,
            "domain": domain,
            "domain_confidence": domain_confidence,
            "is_academic": is_academic,
            "is_code": is_code,
            "is_current_events": is_current_events,
            "selected_engines": selected_engines,
            "num_results_per_engine": {
                engine: len(engine_results)
                for engine, engine_results in search_results.items()
            },
        })

    # Save results to a file
    with open('query_classification_search_results.json', 'w') as f:
        json.dump(results, f, indent=2)

    print("\nResults saved to query_classification_search_results.json")
|
||||
|
||||
|
||||
# Script entry point: run the integration test on a fresh event loop.
if __name__ == "__main__":
    asyncio.run(test_query_classification_search_integration())
|
|
@ -0,0 +1,209 @@
|
|||
"""
|
||||
Test the query domain classification functionality.
|
||||
|
||||
This script tests the new LLM-based query domain classification functionality
|
||||
to ensure it correctly classifies queries into academic, code, current_events,
|
||||
and general categories.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Dict, Any, List
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
|
||||
|
||||
from query.llm_interface import get_llm_interface
|
||||
from query.query_processor import get_query_processor
|
||||
|
||||
|
||||
async def test_classify_query_domain():
    """Test the classify_query_domain function.

    Classifies a batch of queries spanning academic, code, current-events,
    and mixed domains, printing each classification and saving them all to
    a JSON file.
    """
    llm_interface = get_llm_interface()

    test_queries = [
        # Academic queries
        "What are the technological, economic, and social implications of large language models in today's society?",
        "What is the current state of research on quantum computing algorithms?",
        "How has climate change affected biodiversity in marine ecosystems?",

        # Code queries
        "How do I implement a transformer model in PyTorch for text classification?",
        "What's the best way to optimize a recursive function in Python?",
        "Explain how to use React hooks with TypeScript",

        # Current events queries
        "What are the latest developments in the Ukraine conflict?",
        "How has the Federal Reserve's recent interest rate decision affected the stock market?",
        "What were the outcomes of the recent climate summit?",

        # Mixed or general queries
        "How are LLMs being used to detect and prevent cyber attacks?",
        "What are the best practices for remote work?",
        "Compare electric vehicles to traditional gas-powered cars"
    ]

    results = []

    for query in test_queries:
        print(f"\nClassifying query: {query}")
        classification = await llm_interface.classify_query_domain(query)

        primary = classification.get('primary_type')
        print(f"Primary type: {primary} (confidence: {classification.get('confidence')})")

        # Secondary types are optional; print each when present.
        for sec_type in classification.get('secondary_types') or []:
            print(f"Secondary type: {sec_type['type']} (confidence: {sec_type['confidence']})")

        print(f"Reasoning: {classification.get('reasoning', 'None provided')}")

        results.append({
            'query': query,
            'classification': classification
        })

    # Persist the full batch for offline inspection.
    with open('domain_classification_results.json', 'w') as f:
        json.dump(results, indent=2, fp=f)

    print(f"\nResults saved to domain_classification_results.json")
|
||||
|
||||
|
||||
async def test_query_processor_with_domain_classification():
    """Test the query processor with the new domain classification.

    Processes a handful of queries, prints the classification fields the
    processor attaches, and saves a trimmed-down summary to JSON.
    """
    query_processor = get_query_processor()

    test_queries = [
        "What are the technological implications of large language models?",
        "How do I implement a transformer model in PyTorch?",
        "What are the latest developments in the Ukraine conflict?",
        "How are LLMs being used to detect cyber attacks?"
    ]

    results = []

    for query in test_queries:
        print(f"\nProcessing query: {query}")
        structured = await query_processor.process_query(query)

        print(f"Domain: {structured.get('domain')} (confidence: {structured.get('domain_confidence')})")
        print(f"Is academic: {structured.get('is_academic')}")
        print(f"Is code: {structured.get('is_code')}")
        print(f"Is current events: {structured.get('is_current_events')}")

        if structured.get('secondary_domains'):
            for domain in structured.get('secondary_domains'):
                print(f"Secondary domain: {domain['type']} (confidence: {domain['confidence']})")

        print(f"Reasoning: {structured.get('classification_reasoning', 'None provided')}")

        # Keep only the classification-related fields for the report file.
        summary_keys = (
            'domain', 'domain_confidence', 'is_academic', 'is_code',
            'is_current_events', 'secondary_domains', 'classification_reasoning',
        )
        results.append({
            'query': query,
            'structured_query': {key: structured.get(key) for key in summary_keys},
        })

    # Persist the summaries for offline inspection.
    with open('query_processor_domain_results.json', 'w') as f:
        json.dump(results, indent=2, fp=f)

    print(f"\nResults saved to query_processor_domain_results.json")
|
||||
|
||||
|
||||
async def compare_with_keyword_classification():
    """Compare LLM-based classification with keyword-based classification.

    Runs each query twice: once normally (LLM classification) and once with
    the processor monkey-patched to use the keyword-based path, then saves
    both classifications side by side.

    Fix: the monkey-patched ``_structure_query_with_llm`` is now restored in
    a ``finally`` block, so an exception during the keyword-based run can no
    longer leave the shared query processor permanently patched.
    """
    query_processor = get_query_processor()

    # Keep a reference so the monkey patch below can always be undone.
    original_structure_query_with_llm = query_processor._structure_query_with_llm

    # Test queries that might be challenging for keyword-based approach
    test_queries = [
        "How do language models work internally?",  # Could be academic or code
        "What are the best machine learning models for text generation?",  # "models" could trigger code
        "How has ChatGPT changed the AI landscape?",  # Recent but academic topic
        "What techniques help in understanding neural networks?",  # Could be academic or code
        "How are transformers used in NLP applications?",  # Ambiguous - could mean electrical transformers or ML
    ]

    results = []

    for query in test_queries:
        print(f"\nProcessing query with both methods: {query}")

        # First, use LLM-based classification (normal operation)
        structured_query_llm = await query_processor.process_query(query)

        # Now force keyword-based classification by monkey patching; the
        # finally clause guarantees the original method is restored.
        query_processor._structure_query_with_llm = query_processor._structure_query
        try:
            structured_query_keyword = await query_processor.process_query(query)
        finally:
            query_processor._structure_query_with_llm = original_structure_query_with_llm

        # Compare results
        print("LLM Classification:")
        print(f"  Domain: {structured_query_llm.get('domain')}")
        print(f"  Is academic: {structured_query_llm.get('is_academic')}")
        print(f"  Is code: {structured_query_llm.get('is_code')}")
        print(f"  Is current events: {structured_query_llm.get('is_current_events')}")

        print("Keyword Classification:")
        print(f"  Is academic: {structured_query_keyword.get('is_academic')}")
        print(f"  Is code: {structured_query_keyword.get('is_code')}")
        print(f"  Is current events: {structured_query_keyword.get('is_current_events')}")

        results.append({
            'query': query,
            'llm_classification': {
                'domain': structured_query_llm.get('domain'),
                'is_academic': structured_query_llm.get('is_academic'),
                'is_code': structured_query_llm.get('is_code'),
                'is_current_events': structured_query_llm.get('is_current_events')
            },
            'keyword_classification': {
                'is_academic': structured_query_keyword.get('is_academic'),
                'is_code': structured_query_keyword.get('is_code'),
                'is_current_events': structured_query_keyword.get('is_current_events')
            }
        })

    # Save comparison results to a file
    with open('classification_comparison_results.json', 'w') as f:
        json.dump(results, f, indent=2)

    print("\nComparison results saved to classification_comparison_results.json")
|
||||
|
||||
|
||||
async def main():
    """Run tests for query domain classification."""
    # Choose which test to run
    test_type = 1  # Change to 1, 2, or 3 to run different tests

    # Map each selector to its banner and test coroutine.
    runners = {
        1: ("=== Testing classify_query_domain function ===",
            test_classify_query_domain),
        2: ("=== Testing query processor with domain classification ===",
            test_query_processor_with_domain_classification),
        3: ("=== Comparing LLM and keyword classifications ===",
            compare_with_keyword_classification),
    }

    if test_type in runners:
        banner, runner = runners[test_type]
        print(banner)
        await runner()
    else:
        # Any other value runs every test in sequence.
        print("=== Running all tests ===")
        await test_classify_query_domain()
        await test_query_processor_with_domain_classification()
        await compare_with_keyword_classification()
|
||||
|
||||
|
||||
# Script entry point: run the selected classification tests on a fresh
# event loop.
if __name__ == "__main__":
    asyncio.run(main())
|
|
@ -35,8 +35,21 @@ class GradioInterface:
|
|||
self.sub_question_executor = get_sub_question_executor()
|
||||
self.results_dir = Path(__file__).parent.parent / "results"
|
||||
self.results_dir.mkdir(exist_ok=True)
|
||||
self.reports_dir = Path(__file__).parent.parent
|
||||
|
||||
# Create a dedicated reports directory with subdirectories
|
||||
self.reports_dir = Path(__file__).parent.parent / "reports"
|
||||
self.reports_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Create daily subdirectory for organization
|
||||
self.reports_daily_dir = self.reports_dir / datetime.now().strftime("%Y-%m-%d")
|
||||
self.reports_daily_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Create a metadata file to track reports
|
||||
self.reports_metadata_file = self.reports_dir / "reports_metadata.json"
|
||||
if not self.reports_metadata_file.exists():
|
||||
with open(self.reports_metadata_file, "w") as f:
|
||||
json.dump({"reports": []}, f, indent=2)
|
||||
|
||||
self.detail_level_manager = get_report_detail_level_manager()
|
||||
self.config = Config()
|
||||
|
||||
|
@ -206,7 +219,7 @@ class GradioInterface:
|
|||
Path to the generated report
|
||||
"""
|
||||
try:
|
||||
# Create a timestamped output file
|
||||
# Create a timestamped output file in the daily directory
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
model_suffix = ""
|
||||
|
||||
|
@ -225,7 +238,12 @@ class GradioInterface:
|
|||
print(f"Extracted model name: {custom_model}")
|
||||
print(f"Using model suffix: {model_suffix}")
|
||||
|
||||
output_file = self.reports_dir / f"report_{timestamp}{model_suffix}.md"
|
||||
# Create a unique report ID
|
||||
import hashlib
|
||||
report_id = f"{timestamp}_{hashlib.md5(query.encode()).hexdigest()[:8]}"
|
||||
|
||||
# Define the output file path in the daily directory
|
||||
output_file = self.reports_daily_dir / f"report_{report_id}{model_suffix}.md"
|
||||
|
||||
# Get detail level configuration
|
||||
config = self.detail_level_manager.get_detail_level_config(detail_level)
|
||||
|
@ -256,10 +274,6 @@ class GradioInterface:
|
|||
default_model = detail_config.get("model", "unknown")
|
||||
print(f"Default model for {detail_level} detail level: {default_model}")
|
||||
|
||||
# First set the detail level, which will set the default model for this detail level
|
||||
self.report_generator.set_detail_level(detail_level)
|
||||
print(f"After setting detail level, report generator model is: {self.report_generator.model_name}")
|
||||
|
||||
# Then explicitly override with custom model if provided
|
||||
if custom_model:
|
||||
# Extract the actual model name from the display name format
|
||||
|
@ -523,6 +537,19 @@ class GradioInterface:
|
|||
|
||||
print(f"Report saved to: {output_file}")
|
||||
|
||||
# Update report metadata
|
||||
self._update_report_metadata(report_id, {
|
||||
"id": report_id,
|
||||
"timestamp": timestamp,
|
||||
"query": query,
|
||||
"detail_level": detail_level,
|
||||
"query_type": query_type,
|
||||
"model": custom_model if custom_model else config.get("model", "default"),
|
||||
"file_path": str(output_file),
|
||||
"file_size": output_file.stat().st_size,
|
||||
"creation_date": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
return report, str(output_file)
|
||||
|
||||
except Exception as e:
|
||||
|
@ -545,6 +572,111 @@ class GradioInterface:
|
|||
# Remove content between <thinking> and </thinking> tags
|
||||
import re
|
||||
return re.sub(r'<thinking>.*?</thinking>', '', text, flags=re.DOTALL)
|
||||
|
||||
def _update_report_metadata(self, report_id, metadata):
|
||||
"""
|
||||
Update the report metadata file with new report information.
|
||||
|
||||
Args:
|
||||
report_id (str): Unique identifier for the report
|
||||
metadata (dict): Report metadata to store
|
||||
"""
|
||||
try:
|
||||
# Load existing metadata
|
||||
with open(self.reports_metadata_file, 'r') as f:
|
||||
all_metadata = json.load(f)
|
||||
|
||||
# Check if report already exists
|
||||
existing_report = None
|
||||
for i, report in enumerate(all_metadata.get('reports', [])):
|
||||
if report.get('id') == report_id:
|
||||
existing_report = i
|
||||
break
|
||||
|
||||
# Update or add the report metadata
|
||||
if existing_report is not None:
|
||||
all_metadata['reports'][existing_report] = metadata
|
||||
else:
|
||||
all_metadata['reports'].append(metadata)
|
||||
|
||||
# Save updated metadata
|
||||
with open(self.reports_metadata_file, 'w') as f:
|
||||
json.dump(all_metadata, f, indent=2)
|
||||
|
||||
print(f"Updated metadata for report {report_id}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error updating report metadata: {str(e)}")
|
||||
|
||||
def get_all_reports(self):
    """
    Get all report metadata.

    Returns:
        list: Report metadata dictionaries sorted newest-first by
        'creation_date'; an empty list if the metadata cannot be read.
    """
    try:
        # Read the metadata store from disk.
        with open(self.reports_metadata_file, 'r') as f:
            stored = json.load(f)

        # Newest first; entries without a creation date sort last.
        entries = stored.get('reports', [])
        entries.sort(key=lambda item: item.get('creation_date', ''), reverse=True)
        return entries

    except Exception as e:
        print(f"Error getting report metadata: {str(e)}")
        return []
|
||||
|
||||
def delete_report(self, report_id):
    """
    Delete a report and its metadata.

    Removes the report file from disk (when it exists) and drops the
    matching entry from the metadata store.

    Args:
        report_id (str): ID of the report to delete

    Returns:
        bool: True if successful, False otherwise
    """
    try:
        # Load the metadata store.
        with open(self.reports_metadata_file, 'r') as f:
            all_metadata = json.load(f)

        # Locate the entry for this report ID.
        report_to_delete = next(
            (entry for entry in all_metadata.get('reports', [])
             if entry.get('id') == report_id),
            None,
        )

        if not report_to_delete:
            print(f"Report {report_id} not found")
            return False

        # Remove the report file itself, if we know where it lives.
        file_path = report_to_delete.get('file_path')
        print(f"Deleting report: report_id={report_id}, file_path={file_path}")
        if file_path and Path(file_path).exists():
            print(f"File exists: {Path(file_path).exists()}")
            Path(file_path).unlink()
            print(f"Deleted report file: {file_path}")
        else:
            print(f"File not found or file_path is missing")

        # Drop the entry from the metadata list and persist.
        all_metadata['reports'] = [
            entry for entry in all_metadata.get('reports', [])
            if entry.get('id') != report_id
        ]
        with open(self.reports_metadata_file, 'w') as f:
            json.dump(all_metadata, f, indent=2)

        print(f"Deleted report {report_id} from metadata")
        return True

    except Exception as e:
        print(f"Error deleting report: {str(e)}")
        return False
|
||||
|
||||
def get_available_models(self):
|
||||
"""
|
||||
|
@ -600,6 +732,420 @@ class GradioInterface:
|
|||
|
||||
self.model_name_to_description = model_name_to_description
|
||||
return descriptions
|
||||
|
||||
def _get_reports_for_display(self):
|
||||
"""Get reports formatted for display in the UI"""
|
||||
reports = self.get_all_reports()
|
||||
display_data = []
|
||||
|
||||
for report in reports:
|
||||
# Format timestamp for display
|
||||
timestamp = report.get('timestamp', '')
|
||||
creation_date = report.get('creation_date', '')
|
||||
if creation_date:
|
||||
try:
|
||||
# Convert ISO format to datetime and format for display
|
||||
dt = datetime.fromisoformat(creation_date)
|
||||
formatted_date = dt.strftime('%Y-%m-%d %H:%M:%S')
|
||||
except:
|
||||
formatted_date = creation_date
|
||||
else:
|
||||
formatted_date = timestamp
|
||||
|
||||
# Format file size
|
||||
file_size = report.get('file_size', 0)
|
||||
if file_size < 1024:
|
||||
formatted_size = f"{file_size} B"
|
||||
elif file_size < 1024 * 1024:
|
||||
formatted_size = f"{file_size / 1024:.1f} KB"
|
||||
else:
|
||||
formatted_size = f"{file_size / (1024 * 1024):.1f} MB"
|
||||
|
||||
# Add row to display data
|
||||
display_data.append([
|
||||
report.get('id', ''),
|
||||
report.get('query', '')[:50] + ('...' if len(report.get('query', '')) > 50 else ''),
|
||||
report.get('model', '').split('/')[-1], # Show only the model name without provider
|
||||
report.get('detail_level', ''),
|
||||
formatted_date,
|
||||
formatted_size,
|
||||
Path(report.get('file_path', '')).name, # Just the filename
|
||||
])
|
||||
|
||||
return display_data
|
||||
|
||||
def _delete_selected_reports(self, selected_choices):
|
||||
"""Delete selected reports
|
||||
|
||||
Args:
|
||||
selected_choices (list): List of selected checkbox values in format "ID: Query (Model)"
|
||||
|
||||
Returns:
|
||||
tuple: Updated reports table data and updated checkbox choices
|
||||
"""
|
||||
if not selected_choices:
|
||||
# If no reports are selected, just refresh the display
|
||||
reports_data = self._get_reports_for_display()
|
||||
choices = self._get_report_choices(reports_data)
|
||||
return reports_data, choices, "No reports selected for deletion."
|
||||
|
||||
print(f"Selected choices for deletion: {selected_choices}")
|
||||
|
||||
# Extract report IDs from selected choices
|
||||
selected_report_ids = []
|
||||
for choice in selected_choices:
|
||||
try:
|
||||
# Convert to string and handle different input formats
|
||||
choice_str = str(choice).strip().strip('"\'')
|
||||
print(f"Processing choice: '{choice_str}'")
|
||||
|
||||
# Split at the first colon to get the ID
|
||||
if ':' in choice_str:
|
||||
report_id = choice_str.split(':', 1)[0].strip()
|
||||
selected_report_ids.append(report_id)
|
||||
else:
|
||||
# If no colon, use the entire string as ID
|
||||
selected_report_ids.append(choice_str)
|
||||
print(f"Using full string as ID: '{choice_str}'")
|
||||
except Exception as e:
|
||||
print(f"Error processing choice {choice}: {e}")
|
||||
|
||||
print(f"Deleting report IDs: {selected_report_ids}")
|
||||
|
||||
# Delete selected reports
|
||||
deleted_count = 0
|
||||
for report_id in selected_report_ids:
|
||||
if self.delete_report(report_id):
|
||||
deleted_count += 1
|
||||
print(f"Successfully deleted report: {report_id}")
|
||||
else:
|
||||
print(f"Failed to delete report: {report_id}")
|
||||
|
||||
print(f"Deleted {deleted_count} reports")
|
||||
|
||||
# Refresh the table and choices
|
||||
reports_data = self._get_reports_for_display()
|
||||
choices = self._get_report_choices(reports_data)
|
||||
status_message = f"Deleted {deleted_count} report(s)."
|
||||
return reports_data, choices, status_message
|
||||
|
||||
def _download_selected_reports(self, selected_choices):
|
||||
"""Prepare selected reports for download
|
||||
|
||||
Args:
|
||||
selected_choices (list): List of selected checkbox values in format "ID: Query (Model)"
|
||||
|
||||
Returns:
|
||||
list: List of file paths to download
|
||||
"""
|
||||
if not selected_choices:
|
||||
return []
|
||||
|
||||
print(f"Selected choices for download: {selected_choices}")
|
||||
|
||||
# Extract report IDs from selected choices
|
||||
selected_report_ids = []
|
||||
for choice in selected_choices:
|
||||
try:
|
||||
# Convert to string and handle different input formats
|
||||
choice_str = str(choice).strip().strip('"\'')
|
||||
print(f"Processing choice: '{choice_str}'")
|
||||
|
||||
# Split at the first colon to get the ID
|
||||
if ':' in choice_str:
|
||||
report_id = choice_str.split(':', 1)[0].strip()
|
||||
selected_report_ids.append(report_id)
|
||||
else:
|
||||
# If no colon, use the entire string as ID
|
||||
selected_report_ids.append(choice_str)
|
||||
print(f"Using full string as ID: '{choice_str}'")
|
||||
except Exception as e:
|
||||
print(f"Error processing choice {choice}: {e}")
|
||||
|
||||
print(f"Extracted report IDs: {selected_report_ids}")
|
||||
|
||||
# Get file paths for selected reports
|
||||
all_reports = self.get_all_reports()
|
||||
files_to_download = []
|
||||
|
||||
for report_id in selected_report_ids:
|
||||
report = next((r for r in all_reports if r.get('id') == report_id), None)
|
||||
if report and "file_path" in report:
|
||||
file_path = report["file_path"]
|
||||
print(f"Downloading report: report_id={report_id}, file_path={file_path}")
|
||||
# Verify the file exists
|
||||
if os.path.exists(file_path):
|
||||
files_to_download.append(file_path)
|
||||
print(f"Added file for download: {file_path}")
|
||||
else:
|
||||
print(f"Warning: File does not exist: {file_path}")
|
||||
else:
|
||||
print(f"Warning: Could not find report with ID {report_id}")
|
||||
|
||||
return files_to_download
|
||||
|
||||
def _get_report_choices(self, reports_data):
|
||||
"""Generate choices for the checkbox group based on reports data
|
||||
|
||||
Args:
|
||||
reports_data (list): List of report data rows
|
||||
|
||||
Returns:
|
||||
list: List of choices for the checkbox group in format "ID: Query (Model)"
|
||||
"""
|
||||
choices = []
|
||||
# If reports_data is empty, return an empty list
|
||||
if not reports_data:
|
||||
return []
|
||||
|
||||
# Get all reports from the metadata file to ensure IDs are available
|
||||
all_reports = self.get_all_reports()
|
||||
|
||||
# Create a mapping of report IDs to their full data
|
||||
report_map = {report.get('id', ''): report for report in all_reports}
|
||||
|
||||
for row in reports_data:
|
||||
try:
|
||||
report_id = row[0]
|
||||
if not report_id:
|
||||
continue
|
||||
|
||||
# Get data from the table row
|
||||
query = row[1]
|
||||
model = row[2]
|
||||
|
||||
# Format: "ID: Query (Model)"
|
||||
choice_text = f"{report_id}: {query} ({model})"
|
||||
choices.append(choice_text)
|
||||
except (IndexError, TypeError) as e:
|
||||
print(f"Error processing report row: {e}")
|
||||
continue
|
||||
|
||||
return choices
|
||||
|
||||
def _refresh_reports_with_html(self):
    """Refresh the reports list with updated HTML

    Returns:
        tuple: Updated reports data, HTML content, and reset hidden field value
    """
    table = self._get_reports_for_display()
    html_content = create_checkbox_html(self._get_report_choices(table))
    # "[]" resets the hidden JSON selection field.
    return table, html_content, "[]"
|
||||
|
||||
def _delete_selected_reports_with_html(self, selected_json):
    """Delete selected reports and return updated HTML

    Args:
        selected_json (str): JSON string containing selected report IDs

    Returns:
        tuple: Updated reports data, HTML content, reset hidden field value, and status message
    """
    try:
        # Parse JSON with error handling
        if not selected_json or selected_json == "[]":
            selected = []
        else:
            try:
                selected = json.loads(selected_json)
                print(f"Parsed JSON selections: {selected}")
            except Exception as json_err:
                print(f"JSON parse error: {json_err}")
                # If JSON parsing fails, try to extract values directly
                # (best-effort split of a bracketed, comma-separated string).
                selected = [s.strip(' "') for s in selected_json.strip('[]').split(',')]
                print(f"Fallback parsing to: {selected}")

        # Delete reports; the middle return value (choices) is discarded
        # because fresh choices are recomputed from the updated table below.
        updated_table, _, message = self._delete_selected_reports(selected)
        choices = self._get_report_choices(updated_table)
        html_content = create_checkbox_html(choices)
        # "[]" resets the hidden JSON selection field.
        return updated_table, html_content, "[]", f"{message}"
    except Exception as e:
        # Log the full traceback but fail soft with a refreshed, empty-selection UI.
        import traceback
        traceback.print_exc()
        return self._get_reports_for_display(), create_checkbox_html([]), "[]", f"Error: {str(e)}"
|
||||
|
||||
def _download_with_html(self, selected_json):
    """Prepare selected reports for download with improved JSON parsing

    Args:
        selected_json (str): JSON string containing selected report IDs

    Returns:
        list: Files prepared for download
    """
    try:
        # Parse JSON with error handling
        if not selected_json or selected_json == "[]":
            selected = []
        else:
            try:
                selected = json.loads(selected_json)
                print(f"Parsed JSON selections for download: {selected}")
            except Exception as json_err:
                print(f"JSON parse error: {json_err}")
                # If JSON parsing fails, try to extract values directly
                # (best-effort split of a bracketed, comma-separated string).
                selected = [s.strip(' "') for s in selected_json.strip('[]').split(',')]
                print(f"Fallback parsing to: {selected}")

        # Get file paths for download
        files = self._download_selected_reports(selected)
        return files
    except Exception as e:
        # Log the full traceback but fail soft with an empty download list.
        import traceback
        traceback.print_exc()
        return []
|
||||
|
||||
def _cleanup_old_reports(self, days):
    """Delete reports older than the specified number of days

    Args:
        days (int): Number of days to keep reports for

    Returns:
        list: Updated reports table data
    """
    try:
        # A non-positive retention window disables cleanup entirely.
        if days <= 0:
            print("Cleanup skipped - days parameter is 0 or negative")
            return self._get_reports_for_display()

        # Calculate cutoff date
        from datetime import timedelta
        cutoff_date = datetime.now() - timedelta(days=days)
        cutoff_str = cutoff_date.isoformat()
        print(f"Cleaning up reports older than {cutoff_date.strftime('%Y-%m-%d %H:%M:%S')}")

        # Get all reports
        all_reports = self.get_all_reports()
        print(f"Found {len(all_reports)} total reports")
        reports_to_delete = []

        # Find reports older than cutoff date
        for report in all_reports:
            creation_date = report.get('creation_date', '')
            if not creation_date:
                # Reports without a creation date are kept, never cleaned up.
                print(f"Warning: Report {report.get('id')} has no creation date")
                continue

            # NOTE(review): lexicographic comparison of ISO-8601 strings —
            # correct only while all stored dates share the same naive
            # datetime.isoformat() shape; confirm if tz-aware dates appear.
            if creation_date < cutoff_str:
                reports_to_delete.append(report.get('id'))
                print(f"Marking report {report.get('id')} from {creation_date} for deletion")

        print(f"Found {len(reports_to_delete)} reports to delete")

        # Delete old reports
        deleted_count = 0
        for report_id in reports_to_delete:
            if self.delete_report(report_id):
                deleted_count += 1

        print(f"Successfully deleted {deleted_count} reports")

        # Refresh the table
        updated_display = self._get_reports_for_display()
        print(f"Returning updated display with {len(updated_display)} reports")
        return updated_display

    except Exception as e:
        print(f"Error in cleanup_old_reports: {e}")
        import traceback
        traceback.print_exc()
        # Return current display data in case of error
        return self._get_reports_for_display()
|
||||
|
||||
def migrate_existing_reports(self):
    """Migrate existing reports from the root directory to the reports directory structure

    Copies (does not move or delete) each matching report_*.md file into
    a date-named subdirectory of self.reports_dir and registers it in
    the metadata store via _update_report_metadata.

    Returns:
        str: Status message indicating the result of the migration
    """
    import re
    import shutil
    import os

    # Pattern to match report files like report_20250317_122351_llama-3.3-70b-versatile.md
    report_pattern = re.compile(r'report_(?P<date>\d{8})_(?P<time>\d{6})_?(?P<model>.*?)?\.md$')

    # Get the root directory (two levels above this module file)
    root_dir = Path(__file__).parent.parent

    # Find all report files in the root directory
    migrated_count = 0
    for file_path in root_dir.glob('report_*.md'):
        if not file_path.is_file():
            continue

        # Extract information from the filename; files that do not follow
        # the report_<date>_<time>_<model>.md convention are skipped.
        match = report_pattern.match(file_path.name)
        if not match:
            continue

        date_str = match.group('date')
        time_str = match.group('time')
        model = match.group('model') or 'unknown'

        # Format date for directory structure (YYYY-MM-DD)
        try:
            year = date_str[:4]
            month = date_str[4:6]
            day = date_str[6:8]
            formatted_date = f"{year}-{month}-{day}"

            # Create timestamp for metadata
            timestamp = f"{year}-{month}-{day} {time_str[:2]}:{time_str[2:4]}:{time_str[4:6]}"
            creation_date = datetime.strptime(timestamp, "%Y-%m-%d %H:%M:%S").isoformat()
        except ValueError:
            # If date parsing fails, use current date
            # (timestamp keeps the raw formatted string from above).
            formatted_date = datetime.now().strftime("%Y-%m-%d")
            creation_date = datetime.now().isoformat()

        # Create directory for the date if it doesn't exist
        date_dir = self.reports_dir / formatted_date
        date_dir.mkdir(exist_ok=True)

        # Generate a unique report ID from the filename's date and time
        report_id = f"{date_str}_{time_str}"

        # Copy the file to the new location (copy2 preserves timestamps)
        new_file_path = date_dir / file_path.name
        shutil.copy2(file_path, new_file_path)

        # Read the report content to extract query if possible
        query = ""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read(1000)  # Read just the beginning to find the query
                # Try to extract query from title or first few lines
                title_match = re.search(r'#\s*(.+?)\n', content)
                if title_match:
                    query = title_match.group(1).strip()
                else:
                    # Just use the first line as query
                    query = content.split('\n')[0].strip()
        except Exception as e:
            # A read failure leaves query empty; migration still proceeds.
            print(f"Error reading file {file_path}: {e}")

        # Create metadata for the report
        file_size = os.path.getsize(file_path)
        metadata = {
            "id": report_id,
            "query": query,
            "model": model,
            "detail_level": "unknown",  # We don't know the detail level from the filename
            "timestamp": timestamp,
            "creation_date": creation_date,
            "file_path": str(new_file_path),
            "file_size": file_size
        }

        # Update the metadata file
        self._update_report_metadata(report_id, metadata)
        migrated_count += 1

    return f"Migrated {migrated_count} existing reports to the new directory structure."
|
||||
|
||||
def create_interface(self):
|
||||
"""
|
||||
|
@ -624,53 +1170,9 @@ class GradioInterface:
|
|||
"""
|
||||
)
|
||||
|
||||
# Create tabs for different sections
|
||||
with gr.Tabs() as tabs:
|
||||
with gr.TabItem("Search"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=4):
|
||||
search_query_input = gr.Textbox(
|
||||
label="Research Query",
|
||||
placeholder="Enter your research question here...",
|
||||
lines=3
|
||||
)
|
||||
with gr.Column(scale=1):
|
||||
search_num_results = gr.Slider(
|
||||
minimum=5,
|
||||
maximum=50,
|
||||
value=20,
|
||||
step=5,
|
||||
label="Results Per Engine"
|
||||
)
|
||||
search_use_reranker = gr.Checkbox(
|
||||
label="Use Semantic Reranker",
|
||||
value=True,
|
||||
info="Uses Jina AI's reranker for more relevant results"
|
||||
)
|
||||
search_button = gr.Button("Search", variant="primary")
|
||||
|
||||
gr.Examples(
|
||||
examples=[
|
||||
["What are the latest advancements in quantum computing?"],
|
||||
["Compare transformer and RNN architectures for NLP tasks"],
|
||||
["Explain the environmental impact of electric vehicles"],
|
||||
["What recent actions has Trump taken regarding tariffs?"],
|
||||
["What are the recent papers on large language model alignment?"],
|
||||
["What are the main research findings on climate change adaptation strategies in agriculture?"]
|
||||
],
|
||||
inputs=search_query_input
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
search_results_output = gr.Markdown(label="Results")
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
search_file_output = gr.Textbox(
|
||||
label="Results saved to file",
|
||||
interactive=False
|
||||
)
|
||||
|
||||
# Report Generation Tab
|
||||
with gr.TabItem("Generate Report"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=4):
|
||||
|
@ -779,19 +1281,82 @@ class GradioInterface:
|
|||
|
||||
gr.Markdown(f"### Detail Levels\n{detail_levels_info}")
|
||||
gr.Markdown(f"### Query Types\n{query_types_info}")
|
||||
|
||||
# Report Management Tab - Reimplemented from scratch
|
||||
with gr.TabItem("Manage Reports"):
|
||||
with gr.Row():
|
||||
gr.Markdown("## Report Management")
|
||||
|
||||
with gr.Row():
|
||||
gr.Markdown("Select reports to download or delete. You can filter and sort the reports using the table controls.")
|
||||
|
||||
# Get the reports data
|
||||
reports_data = self._get_reports_for_display()
|
||||
|
||||
# Create a state to store selected report IDs
|
||||
selected_report_ids = gr.State([])
|
||||
|
||||
# We've removed the DataTable as requested by the user
|
||||
|
||||
# Selection controls
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
# Create a checkbox group for selecting reports
|
||||
report_choices = self._get_report_choices(reports_data)
|
||||
reports_checkbox_group = gr.CheckboxGroup(
|
||||
choices=report_choices,
|
||||
label="Select Reports",
|
||||
info="Check the reports you want to download or delete",
|
||||
interactive=True
|
||||
)
|
||||
|
||||
with gr.Column(scale=1):
|
||||
# Action buttons
|
||||
with gr.Row():
|
||||
refresh_button = gr.Button("Refresh List", size="sm")
|
||||
|
||||
with gr.Row():
|
||||
select_all_button = gr.Button("Select All", size="sm")
|
||||
clear_selection_button = gr.Button("Clear Selection", size="sm")
|
||||
|
||||
with gr.Row():
|
||||
download_button = gr.Button("Download Selected", size="sm")
|
||||
delete_button = gr.Button("Delete Selected", variant="stop", size="sm")
|
||||
|
||||
with gr.Row():
|
||||
cleanup_days = gr.Slider(
|
||||
minimum=0,
|
||||
maximum=90,
|
||||
value=30,
|
||||
step=1,
|
||||
label="Delete Reports Older Than (Days)",
|
||||
info="Set to 0 to disable automatic cleanup"
|
||||
)
|
||||
cleanup_button = gr.Button("Clean Up Old Reports", size="sm")
|
||||
|
||||
# File download component
|
||||
with gr.Row():
|
||||
file_output = gr.File(
|
||||
label="Downloaded Reports",
|
||||
file_count="multiple",
|
||||
type="filepath",
|
||||
interactive=False
|
||||
)
|
||||
|
||||
# Status message
|
||||
with gr.Row():
|
||||
status_message = gr.Markdown("")
|
||||
|
||||
# Migration button for existing reports
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
gr.Markdown("### Migrate Existing Reports")
|
||||
gr.Markdown("Use this button to migrate existing reports from the root directory to the new reports directory structure.")
|
||||
migrate_button = gr.Button("Migrate Existing Reports", variant="primary")
|
||||
|
||||
# Set up event handlers
|
||||
search_button.click(
|
||||
fn=self.process_query,
|
||||
inputs=[search_query_input, search_num_results, search_use_reranker],
|
||||
outputs=[search_results_output, search_file_output]
|
||||
)
|
||||
|
||||
# Connect the progress callback to the report button
|
||||
# Progress display is now handled entirely by Gradio's built-in progress tracking
|
||||
|
||||
# Update the progress tracking in the generate_report method
|
||||
async def generate_report_with_progress(query, detail_level, query_type, model_name, rerank, token_budget, initial_results, final_results):
|
||||
async def generate_report_with_progress(query, detail_level, query_type, model_name, process_thinking, initial_results, final_results):
|
||||
# Set up progress tracking
|
||||
progress_data = gr.Progress(track_tqdm=True)
|
||||
|
||||
|
@ -799,17 +1364,177 @@ class GradioInterface:
|
|||
print(f"Model selected from UI dropdown: {model_name}")
|
||||
|
||||
# Call the original generate_report method
|
||||
result = await self.generate_report(query, detail_level, query_type, model_name, rerank, token_budget, initial_results, final_results)
|
||||
result = await self.generate_report(
|
||||
query,
|
||||
detail_level,
|
||||
query_type,
|
||||
model_name,
|
||||
None, # results_file is now None since we removed the search tab
|
||||
process_thinking,
|
||||
initial_results,
|
||||
final_results
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
report_button.click(
|
||||
fn=lambda q, d, t, m, r, p, i, f: asyncio.run(generate_report_with_progress(q, d, t, m, r, p, i, f)),
|
||||
fn=lambda q, d, t, m, p, i, f: asyncio.run(generate_report_with_progress(q, d, t, m, p, i, f)),
|
||||
inputs=[report_query_input, report_detail_level, report_query_type, report_custom_model,
|
||||
search_file_output, report_process_thinking, initial_results_slider, final_results_slider],
|
||||
report_process_thinking, initial_results_slider, final_results_slider],
|
||||
outputs=[report_output, report_file_output]
|
||||
)
|
||||
|
||||
# Report Management Tab Event Handlers
|
||||
|
||||
# Refresh reports list
|
||||
def refresh_reports_list():
|
||||
"""Refresh the reports list and update the UI components"""
|
||||
reports_data = self._get_reports_for_display()
|
||||
report_choices = self._get_report_choices(reports_data)
|
||||
return reports_data, report_choices, "Reports list refreshed."
|
||||
|
||||
refresh_button.click(
|
||||
fn=refresh_reports_list,
|
||||
inputs=[],
|
||||
outputs=[reports_checkbox_group, reports_checkbox_group, status_message]
|
||||
)
|
||||
|
||||
# Select all reports
|
||||
def select_all_reports():
|
||||
"""Select all reports in the checkbox group"""
|
||||
report_choices = self._get_report_choices(self._get_reports_for_display())
|
||||
return report_choices, "Selected all reports."
|
||||
|
||||
select_all_button.click(
|
||||
fn=select_all_reports,
|
||||
inputs=[],
|
||||
outputs=[reports_checkbox_group, status_message]
|
||||
)
|
||||
|
||||
# Clear selection
|
||||
def clear_selection():
|
||||
"""Clear the selection in the checkbox group"""
|
||||
return [], "Selection cleared."
|
||||
|
||||
clear_selection_button.click(
|
||||
fn=clear_selection,
|
||||
inputs=[],
|
||||
outputs=[reports_checkbox_group, status_message]
|
||||
)
|
||||
|
||||
# Download selected reports
|
||||
def download_selected_reports(selected_choices):
|
||||
"""Download selected reports"""
|
||||
if not selected_choices:
|
||||
return [], "No reports selected for download."
|
||||
|
||||
print(f"Selected choices for download: {selected_choices}")
|
||||
files = self._download_selected_reports(selected_choices)
|
||||
|
||||
if files:
|
||||
return files, f"Prepared {len(files)} report(s) for download."
|
||||
else:
|
||||
return [], "No files found for the selected reports."
|
||||
|
||||
download_button.click(
|
||||
fn=download_selected_reports,
|
||||
inputs=[reports_checkbox_group],
|
||||
outputs=[file_output, status_message]
|
||||
)
|
||||
|
||||
# Delete selected reports
|
||||
def delete_selected_reports(selected_choices):
|
||||
"""Delete selected reports and update the UI"""
|
||||
if not selected_choices:
|
||||
return self._get_reports_for_display(), [], "No reports selected for deletion."
|
||||
|
||||
print(f"Selected choices for deletion: {selected_choices}")
|
||||
|
||||
# Extract report IDs from selected choices
|
||||
selected_report_ids = []
|
||||
for choice in selected_choices:
|
||||
try:
|
||||
# Split at the first colon to get the ID
|
||||
if ':' in choice:
|
||||
report_id = choice.split(':', 1)[0].strip()
|
||||
selected_report_ids.append(report_id)
|
||||
else:
|
||||
# If no colon, use the entire string as ID
|
||||
selected_report_ids.append(choice)
|
||||
except Exception as e:
|
||||
print(f"Error processing choice {choice}: {e}")
|
||||
|
||||
# Delete selected reports
|
||||
deleted_count = 0
|
||||
for report_id in selected_report_ids:
|
||||
if self.delete_report(report_id):
|
||||
deleted_count += 1
|
||||
|
||||
# Refresh the table and choices
|
||||
updated_reports_data = self._get_reports_for_display()
|
||||
updated_choices = self._get_report_choices(updated_reports_data)
|
||||
|
||||
return updated_choices, f"Deleted {deleted_count} report(s)."
|
||||
|
||||
delete_button.click(
|
||||
fn=delete_selected_reports,
|
||||
inputs=[reports_checkbox_group],
|
||||
outputs=[reports_checkbox_group, status_message]
|
||||
)
|
||||
|
||||
# Clean up old reports
|
||||
def cleanup_old_reports(days):
|
||||
"""Delete reports older than the specified number of days"""
|
||||
if days <= 0:
|
||||
return self._get_reports_for_display(), self._get_report_choices(self._get_reports_for_display()), "Cleanup skipped - days parameter is 0 or negative."
|
||||
|
||||
updated_reports_data = self._cleanup_old_reports(days)
|
||||
updated_choices = self._get_report_choices(updated_reports_data)
|
||||
|
||||
return updated_reports_data, updated_choices, f"Reports older than {days} days have been deleted."
|
||||
|
||||
cleanup_button.click(
|
||||
fn=cleanup_old_reports,
|
||||
inputs=[cleanup_days],
|
||||
outputs=[reports_checkbox_group, status_message]
|
||||
)
|
||||
|
||||
# Migration button event handler
|
||||
def migrate_existing_reports():
|
||||
"""Migrate existing reports from the root directory to the reports directory structure"""
|
||||
print("Starting migration of existing reports...")
|
||||
status = self.migrate_existing_reports()
|
||||
print("Migration completed, refreshing display...")
|
||||
|
||||
# Refresh the reports list
|
||||
updated_reports_data = self._get_reports_for_display()
|
||||
updated_choices = self._get_report_choices(updated_reports_data)
|
||||
|
||||
return status, updated_reports_data, updated_choices
|
||||
|
||||
migrate_button.click(
|
||||
fn=migrate_existing_reports,
|
||||
inputs=[],
|
||||
outputs=[status_message, reports_checkbox_group]
|
||||
)
|
||||
|
||||
# Initialize the UI on page load
|
||||
def init_reports_ui():
|
||||
"""Initialize the reports UI with current data"""
|
||||
print("Initializing reports UI...")
|
||||
reports_data = self._get_reports_for_display()
|
||||
choices = self._get_report_choices(reports_data)
|
||||
|
||||
print(f"Initializing reports UI with {len(reports_data)} reports and {len(choices)} choices")
|
||||
|
||||
return choices, "Reports management initialized successfully."
|
||||
|
||||
interface.load(
|
||||
fn=init_reports_ui,
|
||||
inputs=[],
|
||||
outputs=[reports_checkbox_group, status_message]
|
||||
)
|
||||
|
||||
return interface
|
||||
|
||||
def launch(self, **kwargs):
|
||||
|
|
Loading…
Reference in New Issue