Skip to content

Commit 6b67c6f

Browse files
committed
Simple digit recognizer webapp; Fix #35
1 parent cf08118 commit 6b67c6f

File tree

2 files changed

+72
-0
lines changed

2 files changed

+72
-0
lines changed

server/run_st_app.sh

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
#!/usr/bin/env bash
2+
3+
streamlit run st_app.py

server/st_app.py

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
#!/usr/bin/env python
2+
3+
import streamlit as st
4+
st.set_page_config(page_title="Digit Recognition Tutorial", page_icon=":shark:",)
5+
import os
6+
import time
7+
import glob
8+
import os
9+
from PIL import Image
10+
from omegaconf import OmegaConf
11+
import json
12+
import time
13+
from dotenv import load_dotenv
14+
import wandb
15+
from hydra.utils import instantiate
16+
from nimrod.models.mlp import MLP_PL
17+
import torch
18+
import torchvision.transforms as transforms
19+
from pathlib import Path
20+
from matplotlib import pyplot as plt
21+
22+
load_dotenv()
23+
24+
def predict(x, model):
25+
model.eval()
26+
with torch.no_grad():
27+
# model forward method calls mlp which is (B,C,W*H) unlike datamodule which is (B,C,W,H)
28+
y_hat = model(x).argmax(dim=2)
29+
return y_hat
30+
31+
def process_image(image):
32+
tf = transforms.Compose([transforms.ToTensor(), transforms.Grayscale(),transforms.Resize((28,28))])
33+
x = tf(image).view(1,1, 28*28)
34+
return x
35+
36+
def main():
37+
hide_st_style = """
38+
<style>
39+
#MainMenu {visibility: hidden;}
40+
footer {visibility: hidden;}
41+
header {visibility: hidden;}
42+
</style>
43+
"""
44+
st.header("Digit recognizer")
45+
st.markdown(hide_st_style, unsafe_allow_html=True)
46+
file_uploaded = st.file_uploader("Choose File", type=["png","jpg","jpeg"])
47+
model = load_model()
48+
49+
if file_uploaded is not None:
50+
image = Image.open(file_uploaded)
51+
st.image(image, width=250)
52+
x = process_image(image)
53+
predictions = predict(x, model).item()
54+
st.write("**Recognized digit:**", predictions)
55+
56+
@st.cache_resource()
57+
def load_model():
58+
cfg = OmegaConf.load('../recipes/image/mnist/conf/train.yaml')
59+
model = instantiate(cfg.model)
60+
run = wandb.init()
61+
artifact = run.use_artifact('slegroux/MNIST-HP/model-0hfq6cko:v0', type='model')
62+
artifact_dir = artifact.download()
63+
wandb.finish()
64+
model = MLP_PL.load_from_checkpoint(Path(artifact_dir) / "model.ckpt").to(torch.device('cpu'))
65+
return model
66+
67+
if __name__ == "__main__":
68+
69+
main()

0 commit comments

Comments
 (0)