Skip to content

Commit b0408b1

Browse files
committed
implementing all features
1 parent d337712 commit b0408b1

22 files changed

+385
-129
lines changed

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,6 @@ dmypy.json
106106

107107
# Other
108108
.DS_Store
109+
110+
# Temporary folder
111+
tmp/

README.md

+76-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,80 @@
11
# LLM-based Web Crawler
22

3+
An scalable web crawler, here a list of the feature of this crawler:
4+
5+
* This service can crawl recursively the web storing links it's text and the corresponding text embedding.
6+
* We use a large language model (e.g Bert) to obtain the text embeddings, i.e. a vector representation of the text present at each webiste.
7+
* The service is scalable, we use Ray to spread across multiple workers.
8+
* The entries are stored into a vector database. Vector databases are ideal to save and retrieve samples according to a vector representation.
9+
10+
By saving the representations into a vector database, you can retrieve similar pages according to how close two vectors are. This is critical for a browser to retrieve the most relevant results.
11+
12+
# Start the head and the worker nodes in Ray
13+
14+
## Head node
15+
16+
1. Setup the head node
17+
18+
```sh
19+
ray start --head
20+
```
21+
22+
2. Connect your program to the head node
23+
24+
```py
25+
import ray
26+
27+
# Connect to the head
28+
ray.init("auto")
29+
```
30+
31+
In case you want to stop ray node:
32+
```sh
33+
ray stop
34+
```
35+
36+
Or checking the status:
37+
```sh
38+
ray status
39+
```
40+
41+
## Worker node
42+
43+
1. Initialize the worker node
44+
45+
```sh
46+
ray start
47+
```
48+
49+
The worker node does not need to have the code implementation as the head node will serialize and submit the arguments and implementation to the workers.
50+
51+
# Large Language Model
52+
53+
For our use case, we simply use [BERT](https://arxiv.org/abs/1810.04805) model implemented by [Huggingface](https://huggingface.co/) to extract embeddings from the web text. More precisely, we use [bert-base-uncased](https://huggingface.co/bert-base-uncased). Note that the code is agnostic and new models could be registered and added with few lines of code, take a look to `llm/best.py`.
54+
55+
# Saving crawled data
56+
57+
We use [Milvus](https://milvus.io/) as our main database administrator software. We use a vector-style database due to its inherited capability of searching and saving entries based on vector representations (embeddings).
58+
59+
## Milvus lite
60+
61+
Start your standalone Milvus server as follows, I suggest using an multiplexer software such as `tmux`:
62+
63+
```sh
64+
tmux new -s milvus
65+
milvus-server
66+
```
67+
68+
## Docker compose
69+
70+
You can also use the official `docker compose` template:
71+
72+
```sh
73+
docker compose --file milvus-docker-compose.yml up -d
74+
```
75+
376
## Reference
477

5-
* [Ray Documentation]()
6-
* [Ray in 5 Min]()
78+
* [Ray Documentation](https://docs.ray.io/en/latest/ray-core/examples/gentle_walkthrough.html)
79+
* [Milvus](https://milvus.io/)
80+
* [Huggingface](https://huggingface.co/)

crawl.py

+28-19
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,40 @@
1+
import argparse
2+
13
import ray
2-
from worker import WebCrawler
34

5+
from task import WebCrawler
46

5-
if __name__ == "__main__":
6-
parser = argparse.ArgumentParser(
7-
prog='Distributed NLP Web Crawler',
8-
description='This program can crawl the web and store text embeddings',)
9-
parser.add_argument('-u', '--initial-url', )
10-
parser.add_argument('-b', '--initial-url', )
11-
parser.add_argument('-db', '--db-url', default='http://localhost')
12-
parser.add_argument('-lm', '--language-model', default='bert-base-uncased')
13-
parser.add_argument('-m', '--max-depth', default=2)
14-
args = parser.parse_args()
157

16-
# Initialize Ray
17-
ray.init()
18-
8+
def main(args):
9+
# Prior requisite is to run `ray start --head` in the terminal
10+
# and connect to the existing Ray cluster with the following line
11+
ray.init(address="auto")
12+
1913
# Instantiate Ray worker code
20-
crawler = WebCrawler.remote(
21-
args.initial_url,
22-
args.max_depth
23-
)
14+
crawler = WebCrawler.remote()
2415

2516
print("Starting to crawl...")
26-
ray.get(crawler.crawl.remote(initial_url, 0)) # Initiate the crawling remotely
17+
ray.get(
18+
[crawler.crawl.remote(url, 0, args.max_depth) for url in args.initial_urls]
19+
) # Initiate the crawling remotely
2720

2821
# Wait for all tasks to complete
2922
print("Done crawling.")
3023
ray.shutdown()
3124

25+
26+
if __name__ == "__main__":
27+
parser = argparse.ArgumentParser(
28+
prog="Distributed NLP Web Crawler",
29+
description="This program can crawl the web and store text embeddings",
30+
)
31+
parser.add_argument(
32+
"-u",
33+
"--initial-urls",
34+
nargs="+",
35+
)
36+
parser.add_argument("-lm", "--language-model", default="bert-base-uncased")
37+
parser.add_argument("-m", "--max-depth", default=2)
38+
args = parser.parse_args()
39+
40+
main(args)

db/__init__.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1 @@
1-
from vector_db import VectorDBClient
2-
3-
4-
__all__ = ["VectorDBClient": VectorDBClient]
1+
from db.vector import * # noqa

db/constants.py

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# TODO: most variable here should be set in .env file
2+
MILVUS_HOST = "localhost"
3+
MILVUS_PORT = "19530"
4+
USER = "admin"
5+
PASSWORD = "admin"
6+
URI = f"http://{MILVUS_HOST}:{MILVUS_PORT}"
7+
COLLECTION_NAME = "web_crawler_data"
8+
INDEX_PARAM = {
9+
"metric_type":"L2",
10+
"index_type":"IVF_FLAT",
11+
"params":{"nlist":1024}
12+
}
13+
DB_COLS = {
14+
"URL": "url",
15+
"TEXT": "text",
16+
"EMBED": "embeddings",
17+
}

db/vector.py

+76
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from pymilvus import (Collection, CollectionSchema, DataType, FieldSchema,
2+
connections, utility)
3+
4+
from db.constants import (COLLECTION_NAME, DB_COLS, INDEX_PARAM, MILVUS_HOST,
5+
MILVUS_PORT)
6+
7+
8+
class VectorDBClient:
9+
"""Vector database client."""
10+
11+
def __init__(self, embedding_size: int, batch_size: int):
12+
# Unpack parameters
13+
self.batch_size = batch_size
14+
self.embedding_size = embedding_size
15+
16+
self._setup_db_connection()
17+
18+
self._reset_batch()
19+
20+
def _setup_db_connection(self):
21+
"""Setup the Milvus connection."""
22+
# connections.connect(host=MILVUS_HOST, port=MILVUS_PORT, password=PASSWORD, secure=True)
23+
connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
24+
25+
self.schema = [
26+
FieldSchema(name=DB_COLS["URL"], dtype=DataType.VARCHAR, is_primary=True, max_length=1024),
27+
FieldSchema(name=DB_COLS["TEXT"], dtype=DataType.VARCHAR, max_length=1024),
28+
FieldSchema(name=DB_COLS["EMBED"], dtype=DataType.FLOAT_VECTOR, dim=self.embedding_size),
29+
]
30+
if not utility.has_collection(COLLECTION_NAME):
31+
col_schema = CollectionSchema(fields=self.schema)
32+
self.collection = Collection(name=COLLECTION_NAME, schema=col_schema)
33+
assert utility.has_collection(COLLECTION_NAME), " it could not be created"
34+
self.collection.create_index(field_name=DB_COLS["EMBED"], index_params=INDEX_PARAM)
35+
print("It was created successfully")
36+
else:
37+
self.collection = Collection(name=COLLECTION_NAME)
38+
39+
self.collection.load()
40+
41+
def _reset_batch(self):
42+
"""Reset the batch."""
43+
self.batch = [[], [], []] # url, text, embeddings
44+
45+
def _submit_batch(self):
46+
"""Submit the batch to Milvus."""
47+
self.collection.insert(COLLECTION_NAME, self._batch)
48+
self.collection.flush()
49+
self._reset_batch()
50+
51+
def insert(self, url, text, embeddings):
52+
"""Insert a crawled entry in milvus.
53+
54+
NOTE: the method will only insert the entries in the DB
55+
if the batch size is reached.
56+
57+
Parameters
58+
----------
59+
url : str
60+
The URL of the crawled page.
61+
text : str
62+
The text of the crawled page.
63+
embeddings : numpy.ndarray
64+
The embeddings of the crawled page.
65+
"""
66+
# Insert data into Milvus
67+
self.batch.append([url], [text], [embeddings.tolist()])
68+
69+
if len(self.batch[0]) >= self.batch_size:
70+
self._submit_batch()
71+
72+
def close(self):
73+
"""Close the Milvus connection."""
74+
if len(self.batch[0]):
75+
self._submit_batch()
76+
self.milvus_client.close()

db/vector_db.py

-26
This file was deleted.

language_model/__init__.py

-7
This file was deleted.

language_model/base.py

-30
This file was deleted.

language_model/bert.py

-9
This file was deleted.

llm/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from llm.bert import BaseLanguageModel
2+
3+
MODEL_REGISTRY = {"bert-base-uncased": BaseLanguageModel}

llm/base.py

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import torch
2+
3+
4+
class BaseLanguageModel:
5+
def __init__(
6+
self,
7+
embedding_aggr_fn_name: str = "mean",
8+
):
9+
self.embedding_aggr_fn = embedding_aggr_fn_name
10+
11+
@property
12+
def max_token_length(self):
13+
raise NotImplementedError
14+
15+
def _chunk_tokens(self, tokens):
16+
chunks = []
17+
for i in range(0, len(tokens), self.max_token_length):
18+
chunks.append(tokens[i : i + self.max_token_length])
19+
return chunks
20+
21+
def _aggregate_embeddings(self, embeddings, dim=0):
22+
if self.embedding_aggr_fn == "mean":
23+
embedding_aggr = embeddings.mean(dim=dim)
24+
elif self.embedding_aggr_fn == "max":
25+
embedding_aggr = embeddings.max(dim=dim)
26+
else:
27+
raise NotImplementedError(
28+
"The embedding aggregation function `{self.embedding_aggr_fn}` is not allowed"
29+
)
30+
return embedding_aggr
31+
32+
def text_to_embedding(self, text):
33+
# Tokenize the text
34+
tokens = self.tokenizer.encode(text, add_special_tokens=True)
35+
36+
# Preprocess text just in case the number of tokens is too large
37+
chunks = self._chunk_tokens(tokens)
38+
39+
embedding_per_chunk = []
40+
for chunk in chunks:
41+
# Convert tokens to PyTorch tensors
42+
input_ids = torch.tensor(chunk).unsqueeze(0) # Batch size of 1
43+
# Get BERT model embeddings
44+
with torch.no_grad():
45+
outputs = self.model(input_ids).last_hidden_state.squeeze(0)
46+
# Extract embeddings from the model output
47+
embeddings = self._aggregate_embeddings(outputs)
48+
# Convert embeddings to NumPy array
49+
embedding_per_chunk.append(embeddings)
50+
51+
# Stacking the embeddings all from chunks
52+
embedding_stack = torch.stack(embedding_per_chunk)
53+
embedding_output = self._aggregate_embeddings(embedding_stack).squeeze(0)
54+
55+
return embedding_output.numpy()

0 commit comments

Comments
 (0)