import time
import requests
import json
import gradio as gr
import pdfplumber
import pandas as pd
from datetime import datetime
from google.oauth2.service_account import Credentials
from cnocr import CnOcr
import gspread
from sentence_transformers import SentenceTransformer, models, util
from langchain.document_loaders import UnstructuredPowerPointLoader, UnstructuredWordDocumentLoader

# Load credentials for Google Sheets
credentials = Credentials.from_service_account_file("credentials.json", scopes=["https://www.googleapis.com/auth/spreadsheets"])
client = gspread.authorize(credentials)
sheet = client.open_by_url("https://docs.google.com/spreadsheets/d/16H4M-8hHdOhI68vDIsDFT6T2xcGEvm0A7o5uFlmrzrQ/edit?usp=sharing").sheet1

# Initialize models and utilities
word_embedding_model = models.Transformer('sentence-transformers/all-MiniLM-L6-v2', do_lower_case=True)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode='cls')
embedder = SentenceTransformer(modules=[word_embedding_model, pooling_model])
ocr = CnOcr()

# API URLs and headers
chat_url = 'https://Raghav001-API.hf.space/chatpdf'
chat_emd = 'https://Raghav001-API.hf.space/embedd'
headers = {'Content-Type': 'application/json'}

# Global variables
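# Character budgets: how much chat history and retrieved document text is sent with each request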
history_max_len = 500
all_max_len = 3000
bot = []

def record_to_sheet(timestamp, user_input, answer):
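    # Append one question/answer turn (with timestamp) to the Google Sheet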
    row = [timestamp, user_input, answer]
    sheet.append_row(row)

def doc_emb(doc):
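    # Split the uploaded text into per-line chunks, embed each chunk, and
    # reveal the chat widgets once the embeddings are ready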
    texts = doc.split('\n')
    emb_list = embedder.encode(texts)
    print('emb_list', emb_list)
    print('\n'.join(texts))
    return texts, emb_list, gr.Textbox.update(visible=True), gr.Button.update(visible=True), gr.Markdown.update(value="Success! Let's talk."), gr.Chatbot.update(visible=True)

def get_response(msg, bot, doc_text_list, doc_embeddings):
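    # Pack as much recent chat history and as many relevant document chunks as
    # fit in the character budgets, call the remote chat endpoint, log the
    # exchange to the sheet, and return the last few turns for the chatbot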
    now_len = len(msg)
    req_json = {'question': msg}
    his_bg = -1
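    # Walk backwards through the history, keeping the most recent turns that fit in history_max_len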
    for i in range(len(bot) - 1, -1, -1):
        if now_len + len(bot[i][0]) + len(bot[i][1]) > history_max_len:
            break
        now_len += len(bot[i][0]) + len(bot[i][1])
        his_bg = i
    req_json['history'] = [] if his_bg == -1 else bot[his_bg:]
    query_embedding = embedder.encode([msg])
    cos_scores = util.cos_sim(query_embedding, doc_embeddings)[0]
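    # Rank every document chunk by cosine similarity to the question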
    score_index = [[score, index] for index, score in enumerate(cos_scores)]
    score_index.sort(key=lambda x: x[0], reverse=True)
    print('score_index:\n', score_index)
    print('doc_embeddings:', doc_embeddings)
    index_set, sub_doc_list = set(), []
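    # Greedily take the highest-scoring chunks (plus their neighbours) until
    # the all_max_len character budget is used up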
    for s_i in score_index:
        doc = doc_text_list[s_i[1]]
        if now_len + len(doc) > all_max_len:
            break
        index_set.add(s_i[1])
        now_len += len(doc)
        # Maybe the paragraph is truncated wrong, so add the upper and lower paragraphs
        if s_i[1] > 0 and s_i[1] - 1 not in index_set:
            doc = doc_text_list[s_i[1]-1]
            if now_len + len(doc) > all_max_len:
                break
            index_set.add(s_i[1]-1)
            now_len += len(doc)
        if s_i[1] + 1 < len(doc_text_list) and s_i[1] + 1 not in index_set:
            doc = doc_text_list[s_i[1]+1]
            if now_len + len(doc) > all_max_len:
                break
            index_set.add(s_i[1]+1)
            now_len += len(doc)

    index_list = list(index_set)
    index_list.sort()
    for i in index_list:
        sub_doc_list.append(doc_text_list[i])
    req_json['doc'] = '' if len(sub_doc_list) == 0 else '\n'.join(sub_doc_list)
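    # Send the question, trimmed history and retrieved context to the chat endpoint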
    data = {"content": json.dumps(req_json)}
    print('data:\n', req_json)
    result = requests.post(url=chat_url, data=json.dumps(data), headers=headers)
    res = result.json()['content']
    bot.append([msg, res])
    record_to_sheet(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), msg, res)
    return bot[max(0, len(bot) - 3):]

def up_file(fls):
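    # Sort the uploaded files by extension, then extract text from PDFs
    # (including OCR of embedded images and flattened tables), PPTX and DOCX files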
    doc_text_list = []
    names = [str(i.name) for i in fls]
    pdf = []
    docs = []
    pptx = []

    for i in names:
        if i.endswith(".pdf"):
            pdf.append(i)
        elif i.endswith(".docx"):
            docs.append(i)
        else:
            pptx.append(i)

    # Pdf Extraction
    for idx, file in enumerate(pdf):
        with pdfplumber.open(file) as pdf_doc:
            for i in range(len(pdf_doc.pages)):
                page = pdf_doc.pages[i]
                # extract_text() returns None for image-only pages
                page_text = page.extract_text() or ''
                res_list = page_text.split('\n')[:-1]

                for j in range(len(page.images)):
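                    # Save each embedded image, OCR it, and append any recognised text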
                    img = page.images[j]
                    file_name = f"{str(time.time())}-{str(i)}-{str(j)}.png"
                    with open(file_name, mode='wb') as f:
                        f.write(img['stream'].get_data())
                    try:
                        res = ocr.ocr(file_name)
                    except Exception as e:
                        res = []
                    if len(res) > 0:
                        res_list.append(' '.join([re['text'] for re in res]))

                # Flatten any extracted tables into one JSON record per row
                tables = page.extract_tables()
                for table in tables:
                    df = pd.DataFrame(table[1:], columns=table[0])
                    try:
                        records = json.loads(df.to_json(orient="records"))
                        for rec in records:
                            res_list.append(json.dumps(rec))
                    except Exception as e:
                        res_list.append(str(df))

                doc_text_list += res_list

    # PPTX Extraction
    for i in pptx:
        loader = UnstructuredPowerPointLoader(i)
        data = loader.load()
        doc_text_list += [d.page_content for d in data]

    # Doc Extraction
    for i in docs:
        loader = UnstructuredWordDocumentLoader(i)
        data = loader.load()
        doc_text_list += [d.page_content for d in data]

    doc_text_list = [str(text).strip() for text in doc_text_list if len(str(text).strip()) > 0]
    return gr.Textbox.update(value='\n'.join(doc_text_list), visible=True), gr.Button.update(visible=True), gr.Markdown.update(value="Processing")

def launch_interface():
    # Minimal Blocks layout: upload files, build embeddings, then chat over the retrieved chunks
    with gr.Blocks(title="Document Chatbot") as interface:
        gr.Markdown("# Document Chatbot\nUpload a PDF contract to chat with the AI lawyer.")
        file_box = gr.File(file_count="multiple")
        doc_text = gr.Textbox(label="Extracted text", visible=False, lines=10)
        status = gr.Markdown()
        emb_button = gr.Button("Build embeddings", visible=False)
        msg_box = gr.Textbox(label="Your question", visible=False)
        send_button = gr.Button("Send", visible=False)
        chat = gr.Chatbot(visible=False)
        doc_state, emb_state = gr.State(), gr.State()
        file_box.change(up_file, inputs=[file_box], outputs=[doc_text, emb_button, status])
        emb_button.click(doc_emb, inputs=[doc_text], outputs=[doc_state, emb_state, msg_box, send_button, status, chat])
        send_button.click(get_response, inputs=[msg_box, chat, doc_state, emb_state], outputs=[chat])
    interface.launch()

if __name__ == "__main__":
    launch_interface()