ruslanmv commited on
Commit
49fdd56
·
1 Parent(s): f496b94
Files changed (3) hide show
  1. Dockerfile +4 -9
  2. main.py +41 -23
  3. milvus_singleton.py +9 -14
Dockerfile CHANGED
@@ -2,30 +2,25 @@ FROM python:3.10.8
2
 
3
  WORKDIR /app
4
 
5
- # Copy only requirements.txt first to leverage Docker caching
6
- COPY requirements.txt /app/
7
 
8
- # Create cache and milvus_data directories and set permissions
9
  RUN mkdir -p /app/cache /app/milvus_data && chmod -R 777 /app/cache /app/milvus_data
10
 
11
- # Install dependencies
12
  RUN pip install --no-cache-dir --upgrade -r requirements.txt
13
 
14
- # Create a non-root user
15
  RUN useradd -m -u 1000 user
 
16
  USER user
17
 
18
- # Set environment variables for Hugging Face cache and Milvus data
19
  ENV HF_HOME=/app/cache \
20
  HF_MODULES_CACHE=/app/cache/hf_modules \
21
  MILVUS_DATA_DIR=/app/milvus_data \
22
  HF_WORKER_COUNT=1
23
 
24
- # Copy the application code (now main.py is at the root)
25
  COPY . /app
26
 
27
- # Expose the port Uvicorn will run on
28
  EXPOSE 7860
29
 
30
- # Start Uvicorn (main:app is correct now)
31
  CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
 
2
 
3
  WORKDIR /app
4
 
5
+ COPY requirements.txt /app/requirements.txt
 
6
 
 
7
  RUN mkdir -p /app/cache /app/milvus_data && chmod -R 777 /app/cache /app/milvus_data
8
 
 
9
  RUN pip install --no-cache-dir --upgrade -r requirements.txt
10
 
 
11
  RUN useradd -m -u 1000 user
12
+
13
  USER user
14
 
 
15
  ENV HF_HOME=/app/cache \
16
  HF_MODULES_CACHE=/app/cache/hf_modules \
17
  MILVUS_DATA_DIR=/app/milvus_data \
18
  HF_WORKER_COUNT=1
19
 
 
20
  COPY . /app
21
 
22
+ # Expose port for Uvicorn
23
  EXPOSE 7860
24
 
25
+ # Command to run Uvicorn
26
  CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
main.py CHANGED
@@ -1,44 +1,44 @@
1
  from io import BytesIO
2
- from fastapi import FastAPI, File, UploadFile
3
  from fastapi.encoders import jsonable_encoder
4
  from fastapi.responses import JSONResponse
5
  from fastapi.middleware.cors import CORSMiddleware
6
  from pydantic import BaseModel
7
- from pymilvus import utility, Collection, CollectionSchema, FieldSchema, DataType
 
 
8
  import os
9
  import pypdf
10
  from uuid import uuid4
 
11
  from langchain.text_splitter import RecursiveCharacterTextSplitter
 
12
  from sentence_transformers import SentenceTransformer
13
  import torch
14
  from milvus_singleton import MilvusClientSingleton
15
 
16
- # Set environment variables for Hugging Face cache
17
  os.environ['HF_HOME'] = '/app/cache'
18
  os.environ['HF_MODULES_CACHE'] = '/app/cache/hf_modules'
 
 
 
 
 
 
19
 
20
- # Embedding model
21
- embedding_model = SentenceTransformer('Alibaba-NLP/gte-large-en-v1.5',
22
- trust_remote_code=True,
23
- device='cuda' if torch.cuda.is_available() else 'cpu',
24
- cache_folder='/app/cache')
25
-
26
- # Milvus connection details
27
- collection_name = "rag"
28
- milvus_uri = os.getenv("MILVUS_URI", "http://localhost:19530") # Correct URI for Milvus
29
 
30
- # Initialize Milvus client using singleton
31
- milvus_client = MilvusClientSingleton.get_instance(uri=milvus_uri)
32
 
33
- def document_to_embeddings(content: str) -> list:
34
  return embedding_model.encode(content, show_progress_bar=True)
35
 
36
  app = FastAPI()
37
 
38
- # Add CORS middleware
39
  app.add_middleware(
40
  CORSMiddleware,
41
- allow_origins=["*"], # Replace with allowed origins for production
42
  allow_credentials=True,
43
  allow_methods=["*"],
44
  allow_headers=["*"],
@@ -53,14 +53,20 @@ def create_a_collection(milvus_client, collection_name):
53
  id_field = FieldSchema(name="id", dtype=DataType.VARCHAR, max_length=40, is_primary=True)
54
  content_field = FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=4096)
55
  vector_field = FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=1024)
 
56
  # Define the schema for the collection
57
  schema = CollectionSchema(fields=[id_field, content_field, vector_field])
 
58
  # Create the collection
59
  milvus_client.create_collection(
60
  collection_name=collection_name,
61
  schema=schema
62
  )
 
 
 
63
  collection = Collection(name=collection_name)
 
64
  # Create an index for the collection
65
  # IVF_FLAT index is used here, with metric_type COSINE
66
  index_params = {
@@ -70,10 +76,11 @@ def create_a_collection(milvus_client, collection_name):
70
  "nlist": 128
71
  }
72
  }
 
73
  # Create the index on the vector field
74
  collection.create_index(
75
  field_name="vector",
76
- index_params=index_params
77
  )
78
 
79
  @app.get("/")
@@ -83,15 +90,21 @@ async def root():
83
  @app.post("/insert")
84
  async def insert(file: UploadFile = File(...)):
85
  contents = await file.read()
 
86
  if not milvus_client.has_collection(collection_name):
87
  create_a_collection(milvus_client, collection_name)
 
88
  contents = pypdf.PdfReader(BytesIO(contents))
 
89
  extracted_text = ""
90
  for page_num in range(len(contents.pages)):
91
  page = contents.pages[page_num]
92
  extracted_text += page.extract_text()
93
- splitted_document_data = split_documents(extracted_text)
 
 
94
  print(splitted_document_data)
 
95
  data_objects = []
96
  for doc in splitted_document_data:
97
  data = {
@@ -100,32 +113,37 @@ async def insert(file: UploadFile = File(...)):
100
  "content": doc,
101
  }
102
  data_objects.append(data)
 
103
  print(data_objects)
 
104
  try:
105
  milvus_client.insert(collection_name=collection_name, data=data_objects)
 
106
  except Exception as e:
107
  raise JSONResponse(status_code=500, content={"error": str(e)})
108
  else:
109
  return JSONResponse(status_code=200, content={"result": 'good'})
110
-
111
  class RAGRequest(BaseModel):
112
  question: str
113
-
114
  @app.post("/rag")
115
  async def rag(request: RAGRequest):
116
  question = request.question
117
  if not question:
118
  return JSONResponse(status_code=400, content={"message": "Please a question!"})
 
119
  try:
120
  search_res = milvus_client.search(
121
  collection_name=collection_name,
122
  data=[
123
  document_to_embeddings(question)
124
- ],
125
- limit=5, # Return top 5 results
126
  # search_params={"metric_type": "COSINE"}, # Inner product distance
127
  output_fields=["content"], # Return the text field
128
  )
 
129
  retrieved_lines_with_distances = [
130
  (res["entity"]["content"]) for res in search_res[0]
131
  ]
 
1
  from io import BytesIO
2
+ from fastapi import FastAPI, Form, Depends, Request, File, UploadFile
3
  from fastapi.encoders import jsonable_encoder
4
  from fastapi.responses import JSONResponse
5
  from fastapi.middleware.cors import CORSMiddleware
6
  from pydantic import BaseModel
7
+ from pymilvus import connections
8
+
9
+
10
  import os
11
  import pypdf
12
  from uuid import uuid4
13
+
14
  from langchain.text_splitter import RecursiveCharacterTextSplitter
15
+ from pymilvus import MilvusClient, db, utility, Collection, CollectionSchema, FieldSchema, DataType
16
  from sentence_transformers import SentenceTransformer
17
  import torch
18
  from milvus_singleton import MilvusClientSingleton
19
 
20
+
21
  os.environ['HF_HOME'] = '/app/cache'
22
  os.environ['HF_MODULES_CACHE'] = '/app/cache/hf_modules'
23
+ embedding_model = SentenceTransformer('Alibaba-NLP/gte-large-en-v1.5',
24
+ trust_remote_code=True,
25
+ device='cuda' if torch.cuda.is_available() else 'cpu',
26
+ cache_folder='/app/cache'
27
+ )
28
+ collection_name="rag"
29
 
 
 
 
 
 
 
 
 
 
30
 
31
+ # milvus_client = MilvusClientSingleton.get_instance(uri="/app/milvus_data/milvus_demo.db")
32
+ milvus_client = MilvusClient(uri="/app/milvus_data/milvus_demo.db")
33
 
34
+ def document_to_embeddings(content:str) -> list:
35
  return embedding_model.encode(content, show_progress_bar=True)
36
 
37
  app = FastAPI()
38
 
 
39
  app.add_middleware(
40
  CORSMiddleware,
41
+ allow_origins=["*"], # Replace with the list of allowed origins for production
42
  allow_credentials=True,
43
  allow_methods=["*"],
44
  allow_headers=["*"],
 
53
  id_field = FieldSchema(name="id", dtype=DataType.VARCHAR, max_length=40, is_primary=True)
54
  content_field = FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=4096)
55
  vector_field = FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=1024)
56
+
57
  # Define the schema for the collection
58
  schema = CollectionSchema(fields=[id_field, content_field, vector_field])
59
+
60
  # Create the collection
61
  milvus_client.create_collection(
62
  collection_name=collection_name,
63
  schema=schema
64
  )
65
+
66
+ connections.connect(uri="/app/milvus_data/milvus_demo.db")
67
+
68
  collection = Collection(name=collection_name)
69
+
70
  # Create an index for the collection
71
  # IVF_FLAT index is used here, with metric_type COSINE
72
  index_params = {
 
76
  "nlist": 128
77
  }
78
  }
79
+
80
  # Create the index on the vector field
81
  collection.create_index(
82
  field_name="vector",
83
+ index_params=index_params # Pass the dictionary, not a string
84
  )
85
 
86
  @app.get("/")
 
90
  @app.post("/insert")
91
  async def insert(file: UploadFile = File(...)):
92
  contents = await file.read()
93
+
94
  if not milvus_client.has_collection(collection_name):
95
  create_a_collection(milvus_client, collection_name)
96
+
97
  contents = pypdf.PdfReader(BytesIO(contents))
98
+
99
  extracted_text = ""
100
  for page_num in range(len(contents.pages)):
101
  page = contents.pages[page_num]
102
  extracted_text += page.extract_text()
103
+
104
+ splitted_document_data = split_documents(extracted_text)
105
+
106
  print(splitted_document_data)
107
+
108
  data_objects = []
109
  for doc in splitted_document_data:
110
  data = {
 
113
  "content": doc,
114
  }
115
  data_objects.append(data)
116
+
117
  print(data_objects)
118
+
119
  try:
120
  milvus_client.insert(collection_name=collection_name, data=data_objects)
121
+
122
  except Exception as e:
123
  raise JSONResponse(status_code=500, content={"error": str(e)})
124
  else:
125
  return JSONResponse(status_code=200, content={"result": 'good'})
126
+
127
  class RAGRequest(BaseModel):
128
  question: str
129
+
130
  @app.post("/rag")
131
  async def rag(request: RAGRequest):
132
  question = request.question
133
  if not question:
134
  return JSONResponse(status_code=400, content={"message": "Please a question!"})
135
+
136
  try:
137
  search_res = milvus_client.search(
138
  collection_name=collection_name,
139
  data=[
140
  document_to_embeddings(question)
141
+ ],
142
+ limit=5, # Return top 3 results
143
  # search_params={"metric_type": "COSINE"}, # Inner product distance
144
  output_fields=["content"], # Return the text field
145
  )
146
+
147
  retrieved_lines_with_distances = [
148
  (res["entity"]["content"]) for res in search_res[0]
149
  ]
milvus_singleton.py CHANGED
@@ -7,21 +7,16 @@ class MilvusClientSingleton:
7
  @staticmethod
8
  def get_instance(uri):
9
  if MilvusClientSingleton._instance is None:
10
- MilvusClientSingleton(uri)
 
 
 
 
 
 
11
  return MilvusClientSingleton._instance
12
 
13
- def __init__(self, uri):
14
  if MilvusClientSingleton._instance is not None:
15
  raise Exception("This class is a singleton!")
16
- try:
17
- # Use connections.connect() to establish the connection
18
- connections.connect(uri=uri)
19
- self._instance = connections # Store the connections object
20
- print(f"Successfully connected to Milvus at {uri}")
21
- except ConnectionConfigException as e:
22
- print(f"Error connecting to Milvus: {e}")
23
- raise
24
-
25
- def __getattr__(self, name):
26
- # Delegate attribute access to the default connection
27
- return getattr(connections, name)
 
7
  @staticmethod
8
  def get_instance(uri):
9
  if MilvusClientSingleton._instance is None:
10
+ MilvusClientSingleton()
11
+ # Initialize the client here
12
+ try:
13
+ MilvusClientSingleton._instance = connections.connect(uri=uri)
14
+ except ConnectionConfigException as e:
15
+ print(f"Error connecting to Milvus: {e}")
16
+ # Handle error appropriately
17
  return MilvusClientSingleton._instance
18
 
19
+ def __init__(self):
20
  if MilvusClientSingleton._instance is not None:
21
  raise Exception("This class is a singleton!")
22
+ self._instance = None