1
+ #!/usr/bin/env python
2
+
3
+ import streamlit as st
4
+ st .set_page_config (page_title = "Digit Recognition Tutorial" , page_icon = ":shark:" ,)
5
+ import os
6
+ import time
7
+ import glob
8
+ import os
9
+ from PIL import Image
10
+ from omegaconf import OmegaConf
11
+ import json
12
+ import time
13
+ from dotenv import load_dotenv
14
+ import wandb
15
+ from hydra .utils import instantiate
16
+ from nimrod .models .mlp import MLP_PL
17
+ import torch
18
+ import torchvision .transforms as transforms
19
+ from pathlib import Path
20
+ from matplotlib import pyplot as plt
21
+
22
+ load_dotenv ()
23
+
24
+ def predict (x , model ):
25
+ model .eval ()
26
+ with torch .no_grad ():
27
+ # model forward method calls mlp which is (B,C,W*H) unlike datamodule which is (B,C,W,H)
28
+ y_hat = model (x ).argmax (dim = 2 )
29
+ return y_hat
30
+
31
+ def process_image (image ):
32
+ tf = transforms .Compose ([transforms .ToTensor (), transforms .Grayscale (),transforms .Resize ((28 ,28 ))])
33
+ x = tf (image ).view (1 ,1 , 28 * 28 )
34
+ return x
35
+
36
+ def main ():
37
+ hide_st_style = """
38
+ <style>
39
+ #MainMenu {visibility: hidden;}
40
+ footer {visibility: hidden;}
41
+ header {visibility: hidden;}
42
+ </style>
43
+ """
44
+ st .header ("Digit recognizer" )
45
+ st .markdown (hide_st_style , unsafe_allow_html = True )
46
+ file_uploaded = st .file_uploader ("Choose File" , type = ["png" ,"jpg" ,"jpeg" ])
47
+ model = load_model ()
48
+
49
+ if file_uploaded is not None :
50
+ image = Image .open (file_uploaded )
51
+ st .image (image , width = 250 )
52
+ x = process_image (image )
53
+ predictions = predict (x , model ).item ()
54
+ st .write ("**Recognized digit:**" , predictions )
55
+
56
+ @st .cache_resource ()
57
+ def load_model ():
58
+ cfg = OmegaConf .load ('../recipes/image/mnist/conf/train.yaml' )
59
+ model = instantiate (cfg .model )
60
+ run = wandb .init ()
61
+ artifact = run .use_artifact ('slegroux/MNIST-HP/model-0hfq6cko:v0' , type = 'model' )
62
+ artifact_dir = artifact .download ()
63
+ wandb .finish ()
64
+ model = MLP_PL .load_from_checkpoint (Path (artifact_dir ) / "model.ckpt" ).to (torch .device ('cpu' ))
65
+ return model
66
+
67
+ if __name__ == "__main__" :
68
+
69
+ main ()
0 commit comments