Skip to content

Commit d248433

Browse files
committed
upt
1 parent 09abb3f commit d248433

File tree

6 files changed

+20
-9
lines changed

6 files changed

+20
-9
lines changed

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,4 +57,4 @@ docs/doctrees
5757
docs/doxygen_output
5858

5959
# django
60-
python/mrt/frontend/db.sqlite3
60+
python/mrt/web/db.sqlite3

python/mrt/V3/evaluate.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
"--iter-num", type=int, default=10).dest: (_pname, "ITER_NUM"),
3636
})
3737

38-
def evaluate(cm_cfg, pass_cfg):
38+
def evaluate(cm_cfg, pass_cfg, logger=None):
3939
model_dir = cm_cfg.MODEL_DIR
4040
model_name = cm_cfg.MODEL_NAME
4141
verbosity = cm_cfg.VERBOSITY
@@ -45,7 +45,8 @@ def evaluate(cm_cfg, pass_cfg):
4545
batch = pass_cfg.BATCH
4646

4747
model_prefix = get_model_prefix(model_dir, model_name)
48-
logger = get_logger(verbosity)
48+
if logger is None:
49+
logger = get_logger(verbosity)
4950
conf_quant_file = model_prefix + ".quantize.conf"
5051
check_file_existance(conf_quant_file, logger=logger)
5152
conf_map = load_conf(conf_quant_file, logger=logger)

python/mrt/V3/mrt_compile.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
default=default_device_ids).dest: (_cnode, "DEVICE_IDS"),
3636
})
3737

38-
def mrt_compile(cm_cfg, pass_cfg):
38+
def mrt_compile(cm_cfg, pass_cfg, logger=None):
3939
model_dir = cm_cfg.MODEL_DIR
4040
model_name = cm_cfg.MODEL_NAME
4141
verbosity = cm_cfg.VERBOSITY
@@ -45,7 +45,8 @@ def mrt_compile(cm_cfg, pass_cfg):
4545
batch = pass_cfg.BATCH
4646

4747
model_prefix = get_model_prefix(model_dir, model_name)
48-
logger = get_logger(verbosity)
48+
if logger is None:
49+
logger = get_logger(verbosity)
4950
conf_quant_file = model_prefix + ".quantize.conf"
5051
check_file_existance(conf_quant_file, logger=logger)
5152
conf_map = load_conf(conf_quant_file, logger=logger)

python/mrt/V3/quantize.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
"--oscale-maps", type=tuple, default=[]).dest: (_pname, "OSCALE_MAPS"),
4949
})
5050

51-
def quantize(cm_cfg, pass_cfg):
51+
def quantize(cm_cfg, pass_cfg, logger=None):
5252
model_dir = cm_cfg.MODEL_DIR
5353
model_name = cm_cfg.MODEL_NAME
5454
verbosity = cm_cfg.VERBOSITY
@@ -65,7 +65,8 @@ def quantize(cm_cfg, pass_cfg):
6565
oscale_maps = {opn1: opn2 for opn1, opn2 in pass_cfg.OSCALE_MAPS}
6666

6767
model_prefix = get_model_prefix(model_dir, model_name)
68-
logger = get_logger(verbosity)
68+
if logger is None:
69+
logger = get_logger(verbosity)
6970
conf_calib_file = model_prefix + ".calibrate.conf"
7071
check_file_existance(conf_calib_file, logger=logger)
7172
conf_map = load_conf(conf_calib_file, logger=logger)

python/mrt/web/web/urls.py

+1
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,5 @@
2222
path('calibrate/', views.mrt_calibrate_log),
2323
path('quantize/', views.mrt_quantize_log),
2424
path('evaluate/', views.mrt_evaluate_log),
25+
path('compile/', views.mrt_compile_log),
2526
]

python/mrt/web/web/views.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from mrt.V3.calibrate import calibrate
1313
from mrt.V3.quantize import quantize
1414
from mrt.V3.evaluate import evaluate
15-
# from mrt.V3.mrt_compile import mrt_compile
15+
from mrt.V3.mrt_compile import mrt_compile
1616
from .log import get_logger
1717

1818
class Printer:
@@ -55,7 +55,7 @@ def start(self):
5555
yield f'{item}<br>'
5656
except Empty:
5757
pass
58-
yield 'End'
58+
yield '<br>***End***<br>'
5959
printer.clean(self.thread)
6060

6161
mrt_web_tmp_dir = os.path.expanduser("~/.mrt_web")
@@ -96,3 +96,10 @@ def mrt_evaluate_log(request):
9696
logger = get_logger(cm_cfg.VERBOSITY, printer)
9797
streamer = Streamer(evaluate, (cm_cfg, pass_cfg, logger))
9898
return StreamingHttpResponse(streamer.start())
99+
100+
def mrt_compile_log(request):
101+
yaml_file = os.path.expanduser("~/mrt_yaml_root/alexnet.yaml")
102+
cm_cfg, pass_cfg = get_cfg(yaml_file, "COMPILE")
103+
logger = get_logger(cm_cfg.VERBOSITY, printer)
104+
streamer = Streamer(mrt_compile, (cm_cfg, pass_cfg, logger))
105+
return StreamingHttpResponse(streamer.start())

0 commit comments

Comments
 (0)