Skip to content

Commit 7b16220

Browse files
committed
Merge branch 'master' of https://github.com/shubham1172/VQA
2 parents fbc6da8 + 1a64423 commit 7b16220

File tree

4 files changed

+49
-8
lines changed

4 files changed

+49
-8
lines changed

modules/answer_generator.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,23 @@
1+
from deeppavlov import build_model, configs
2+
from utils.nostderrout import nostderrout
3+
4+
15
class AnswerGenerator:
2-
@staticmethod
3-
def predict(comprehension, question):
6+
def __init__(self):
7+
"""
8+
create a model from pre-trained weights
9+
"""
10+
print("START")
11+
with nostderrout():
12+
self.model = build_model(configs.squad.squad)
13+
print("END")
14+
15+
def predict(self, comprehension, question):
416
"""
517
predict answer for a question based on the given data
618
:param comprehension: data to read for answering the question
719
:param question: question based on the comprehension
820
:return: answer for the question
921
"""
10-
pass
22+
return self.model([comprehension], [question])[0]
1123

12-
ag = AnswerGenerator()
13-
print(ag.predict("There are two people.", "How many people are there?"))

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ colorthief
99
torch
1010
torchvision
1111
num2words
12+
deeppavlov

run.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import cv2
22
import argparse
33
import knowledge_graph
4-
import matplotlib.pyplot as plt
54
from pprint import pprint
5+
from modules.answer_generator import AnswerGenerator
6+
from modules.paragraph_generator import ParagraphGenerator
67

78
parser = argparse.ArgumentParser()
89
parser.add_argument('-p', '--path', help='path of the input image', required=True)
@@ -13,5 +14,15 @@
1314

1415
knowledge, frame = knowledge_graph.create_knowledge_graph(image)
1516
pprint(knowledge)
16-
plt.imshow(frame)
17-
plt.show()
17+
18+
paragraph_generator = ParagraphGenerator()
19+
paragraph = paragraph_generator.generate(knowledge)
20+
21+
answer_generator = AnswerGenerator()
22+
print(paragraph)
23+
24+
while True:
25+
question = input().strip()
26+
if question == "":
27+
break
28+
print(answer_generator.predict(paragraph, question))

utils/nostderrout.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import contextlib
2+
import sys
3+
import os
4+
5+
6+
class DummyFile(object):
7+
def write(self, x): pass
8+
9+
10+
@contextlib.contextmanager
11+
def nostderrout():
12+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
13+
save_stdout = sys.stdout
14+
save_stderr = sys.stderr
15+
sys.stdout = DummyFile()
16+
sys.stderr = DummyFile()
17+
yield
18+
sys.stdout = save_stdout
19+
sys.stderr = save_stderr

0 commit comments

Comments
 (0)