-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain_ranker.sh
63 lines (51 loc) · 1.82 KB
/
train_ranker.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Parse command-line arguments and derive the project/output directory layout.
#   $1  embedding model name (only "ada_002" is accepted further below)
#   $2  fold index, used to select data/ranker_data/fold_<N>
#   $3  value for --max_positive_examples (default: 2)
#   $4  value for --max_negative_examples (default: 2)
#   $5  value for --alpha (default: 0)
if [[ $# -lt 2 ]]; then
  echo "Usage: $0 <codex-model-name[ada_002]> <fold 0/1/2/3/4> [max_positive] [max_negative] [alpha]" >&2
  exit 1
fi
# The project root is the parent of the directory containing this script.
SCRIPT_PATH="$(dirname -- "${BASH_SOURCE[0]}")"
BASE_DIR="$(realpath "${SCRIPT_PATH}/../")" || {
  echo "Cannot resolve project base directory from ${SCRIPT_PATH}" >&2
  exit 1
}
echo "PROJECT BASE DIRECTORY: $BASE_DIR"
# Prepend the project root. Only append ':$PYTHONPATH' when it is non-empty:
# an empty PYTHONPATH component makes Python search the current directory.
export PYTHONPATH="${BASE_DIR}${PYTHONPATH:+:${PYTHONPATH}}"
im=$1
fold=$2
# fold is spliced into data and output paths, so reject anything that is
# not a plain non-negative integer (e.g. "../x").
if [[ ! $fold =~ ^[0-9]+$ ]]; then
  echo "Invalid fold: $fold (expected a non-negative integer)" >&2
  exit 1
fi
mp=${3:-"2"}
mn=${4:-"2"}
alpha=${5:-"0"}
# Results are grouped by model, hyper-parameters and fold.
name="${im}/mp-${mp}_mn-${mn}_alpha-${alpha}/fold_${fold}"
DATA_BASE_DIR="${BASE_DIR}/data"
OUTPUT_DIR="${BASE_DIR}/models/ranker_result/${name}"
CONFIG_FILE="${BASE_DIR}/configs/ranker_config.json"
LOG_DIR="${OUTPUT_DIR}/logs"
mkdir -p -- "${LOG_DIR}" || exit 1
echo "$DATA_BASE_DIR"
# Map the user-facing model name ($im) to the embedding model id passed to
# the ranker as --codex_model. The davinci similarity model is deprecated
# and rejected outright; any other name is an error.
case "$im" in
  ada_002)
    codex_model="text-embedding-ada-002"
    ;;
  davinci)
    echo "davinci-similarity model is deprecated. Please use ada_002 instead." >&2
    exit 1
    ;;
  *)
    echo "Invalid codex model name: $im" >&2
    exit 1
    ;;
esac
# The precomputed embedding file for the chosen model must exist before
# training; otherwise print how to generate it (on stderr) and abort.
embedding_path="${DATA_BASE_DIR}/embeddings/${im}.json"
if [[ ! -f $embedding_path ]]; then
  {
    echo "Embedding file not found: $embedding_path"
    echo "Please run 'python ${DATA_BASE_DIR}/get_initial_embeddings.py' first!"
    echo "Run 'python ${DATA_BASE_DIR}/get_initial_embeddings.py --help' for more information."
  } >&2
  exit 1
fi
# Train the ranker on the selected fold. Both stdout and stderr of the
# trainer are shown on the terminal and copied to the log file.
ranker_args=(
  --data_path "${DATA_BASE_DIR}/ranker_data/fold_${fold}"
  --embedding_path "$embedding_path"
  --training_config "${CONFIG_FILE}"
  --output_dir "${OUTPUT_DIR}"
  --initial_model codex
  --codex_model "$codex_model"
  --max_positive_examples "${mp}"
  --max_negative_examples "${mn}"
  --alpha "${alpha}"
  --do_train
  --data_cache_path "${OUTPUT_DIR}/data_cache"
)
# A pipeline's exit status is that of its last stage (tee), which would hide
# a failed training run, so exit with python's own status from PIPESTATUS.
python "$BASE_DIR/src/ranker/main.py" "${ranker_args[@]}" 2>&1 | tee "${LOG_DIR}/train_and_evaluate.log"
exit "${PIPESTATUS[0]}"