
Commit dabded4

Release v0.2.1 with docs update (#74)
1 parent: d08b71e

File tree

2 files changed: +10 −10 lines


docs/online-inference-with-maxtext-engine.md

+6 −6
````diff
@@ -108,7 +108,7 @@ export ICI_AUTOREGRESSIVE_PARALLELISM=-1
 export ICI_TENSOR_PARALLELISM=1
 export SCAN_LAYERS=false
 export WEIGHT_DTYPE=bfloat16
-export PER_DEVICE_BATCH_SIZE=4
+export PER_DEVICE_BATCH_SIZE=11
 ```
 
 #### Create Llama2-7b environment variables for server flags
@@ -126,7 +126,7 @@ export ICI_AUTOREGRESSIVE_PARALLELISM=-1
 export ICI_TENSOR_PARALLELISM=1
 export SCAN_LAYERS=false
 export WEIGHT_DTYPE=bfloat16
-export PER_DEVICE_BATCH_SIZE=4
+export PER_DEVICE_BATCH_SIZE=11
 ```
 
 #### Create Llama2-13b environment variables for server flags
@@ -146,7 +146,7 @@ export ICI_AUTOREGRESSIVE_PARALLELISM=-1
 export ICI_TENSOR_PARALLELISM=1
 export SCAN_LAYERS=false
 export WEIGHT_DTYPE=bfloat16
-export PER_DEVICE_BATCH_SIZE=2
+export PER_DEVICE_BATCH_SIZE=4
 ```
 
 ### Run the following command to start the JetStream MaxText server
@@ -182,7 +182,7 @@ python MaxText/maxengine_server.py \
 * ici\_autoregressive\_parallelism: The number of shards for autoregressive parallelism
 * ici\_tensor\_parallelism: The number of shards for tensor parallelism
 * weight\_dtype: Weight data type (e.g. bfloat16)
-* scan\_layers: Scan layers boolean flag
+* scan\_layers: Scan layers boolean flag (set to `false` for inference)
 
 Note: these flags are from [MaxText config](https://github.com/google/maxtext/blob/f9e04cdc1eec74a0e648411857c09403c3358461/MaxText/configs/base.yml)
 
@@ -200,7 +200,7 @@ python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokeniz
 The output will be similar to the following:
 
 ```bash
-Sending request to: dns:///[::1]:9000
+Sending request to: 0.0.0.0:9000
 Prompt: Today is a good day
 Response: to be a fan
 ```
@@ -253,7 +253,7 @@ wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/r
 # run benchmark with the downloaded dataset and the tokenizer in maxtext
 # You can control the qps by setting `--request-rate`, the default value is inf.
 python JetStream/benchmarks/benchmark_serving.py \
---tokenizer maxtext/assets/tokenizer.gemma \
+--tokenizer maxtext/assets/tokenizer.gemma \
 --num-prompts 1000 \
 --dataset sharegpt \
 --dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
````
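Assembled from the diff context above, the Llama2-7b server-flag environment after this commit would read as follows. This is a sketch; only `PER_DEVICE_BATCH_SIZE` (4 → 11) changed in the commit, the other values are unchanged context lines.

```shell
#!/bin/sh
# Server-flag environment variables for Llama2-7b after this commit.
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false          # scan_layers must be false for inference
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11   # raised from 4 by this commit

echo "PER_DEVICE_BATCH_SIZE=${PER_DEVICE_BATCH_SIZE}"
```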

jetstream/tools/maxtext/model_ckpt_conversion.sh

+4 −4
````diff
@@ -71,17 +71,17 @@ else
 fi
 echo "Written MaxText compatible checkpoint to ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}"
 
-# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory.
-export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}/0/items
+# We define `SCANNED_CKPT_PATH` to refer to the checkpoint subdirectory.
+export SCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}/0/items
 
 # Covert MaxText compatible checkpoints to unscanned checkpoints.
-# Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format.
+# Note that the `SCANNED_CKPT_PATH` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format.
 export RUN_NAME=${MODEL_NAME}_unscanned_chkpt_${idx}
 
 JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py \
 MaxText/configs/base.yml \
 base_output_directory=${BASE_OUTPUT_DIRECTORY} \
-load_parameters_path=${CONVERTED_CHECKPOINT} \
+load_parameters_path=${SCANNED_CKPT_PATH} \
 run_name=${RUN_NAME} \
 model_name=${MODEL_NAME} \
 force_unroll=true
````
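The rename above can be sketched as follows. The bucket, model, variation, and `idx` values here are hypothetical placeholders for illustration, not values from the repo; only the `SCANNED_CKPT_PATH` name and path shape come from the diff.

```shell
#!/bin/sh
# Hypothetical placeholder values (not from the repo), for illustration only.
MODEL_BUCKET=gs://example-bucket
MODEL=llama2
MODEL_VARIATION=7b
idx=0

# After this commit the scanned checkpoint path is exported as SCANNED_CKPT_PATH
# (previously CONVERTED_CHECKPOINT) and passed as load_parameters_path.
export SCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}/0/items
echo "load_parameters_path=${SCANNED_CKPT_PATH}"
```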
