Commit d3b5d70

Enable quantizing local checkpoints in model release script
Summary:
For the torchao model release scripts, we previously only supported quantizing models downloaded directly from Hugging Face (with a model id), and the result was always uploaded. This PR turns uploading to the Hub off by default and allows users to quantize a local checkpoint.

Test Plan:
cd .github/scripts/torchao_model_releases/
./release.sh --model_id $LOCAL_MODEL_PATH --quants FP8

Reviewers:
Subscribers:
Tasks:
Tags:
1 parent 27f4d75 commit d3b5d70

2 files changed: +26, -9 lines
.github/scripts/torchao_model_releases/quantize_and_upload.py

Lines changed: 18 additions & 6 deletions
@@ -568,7 +568,7 @@ def _untie_weights_and_save_locally(model_id):
     """
 
 
-def quantize_and_upload(model_id, quant):
+def quantize_and_upload(model_id, quant, push_to_hub):
     _int8_int4_linear_config = Int8DynamicActivationIntxWeightConfig(
         weight_dtype=torch.int4,
         weight_granularity=PerGroup(32),
@@ -579,7 +579,9 @@ def quantize_and_upload(model_id, quant):
         granularity=PerAxis(0),
     )
     quant_to_config = {
-        "FP8": Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
+        "FP8": Float8DynamicActivationFloat8WeightConfig(
+            granularity=PerRow(), kernel_preference="torch"
+        ),
         "INT4": Int4WeightOnlyConfig(group_size=128),
         "INT8-INT4": ModuleFqnToConfig(
             {
@@ -657,9 +659,13 @@ def quantize_and_upload(model_id, quant):
     card = ModelCard(content)
 
     # Push to hub
-    quantized_model.push_to_hub(quantized_model_id, safe_serialization=False)
-    tokenizer.push_to_hub(quantized_model_id)
-    card.push_to_hub(quantized_model_id)
+    if push_to_hub:
+        quantized_model.push_to_hub(quantized_model_id, safe_serialization=False)
+        tokenizer.push_to_hub(quantized_model_id)
+        card.push_to_hub(quantized_model_id)
+    else:
+        quantized_model.save_pretrained(quantized_model_id, safe_serialization=False)
+        tokenizer.save_pretrained(quantized_model_id)
 
     # Manual Testing
     prompt = "Hey, are you conscious? Can you talk to me?"
@@ -700,5 +706,11 @@ def quantize_and_upload(model_id, quant):
         type=str,
         help="Quantization method. Options are FP8, INT4, INT8_INT4, AWQ-INT4",
     )
+    parser.add_argument(
+        "--push_to_hub",
+        action="store_true",
+        default=False,
+        help="Flag to indicate whether push to huggingface hub or not",
+    )
     args = parser.parse_args()
-    quantize_and_upload(args.model_id, args.quant)
+    quantize_and_upload(args.model_id, args.quant, args.push_to_hub)
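With this change, uploading becomes opt-in when calling the Python script directly. A minimal usage sketch, assuming the script is run from .github/scripts/torchao_model_releases/ (the checkpoint path below is a placeholder, not from the commit):

  # New default: quantize and save the result locally via save_pretrained
  python quantize_and_upload.py --model_id /path/to/local/checkpoint --quant FP8

  # Opt back in to pushing the quantized model, tokenizer, and model card to the Hub
  python quantize_and_upload.py --model_id /path/to/local/checkpoint --quant FP8 --push_to_hub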

.github/scripts/torchao_model_releases/release.sh

Lines changed: 8 additions & 3 deletions
@@ -14,6 +14,7 @@
 
 # Default quantization options
 default_quants=("FP8" "INT4" "INT8-INT4")
+push_to_hub=""
 # Parse arguments
 while [[ $# -gt 0 ]]; do
   case "$1" in
@@ -29,6 +30,10 @@ while [[ $# -gt 0 ]]; do
         shift
       done
       ;;
+    --push_to_hub)
+      push_to_hub="--push_to_hub"
+      shift
+      ;;
     *)
       echo "Unknown option: $1"
       exit 1
@@ -38,14 +43,14 @@ done
 # Use default quants if none specified
 if [[ -z "$model_id" ]]; then
   echo "Error: --model_id is required"
-  echo "Usage: $0 --model_id <model_id> [--quants <quant1> [quant2 ...]]"
+  echo "Usage: $0 --model_id <model_id> [--quants <quant1> [quant2 ...]] [--upload_to_hub]"
   exit 1
 fi
 if [[ ${#quants[@]} -eq 0 ]]; then
   quants=("${default_quants[@]}")
 fi
 # Run the python command for each quantization option
 for quant in "${quants[@]}"; do
-  echo "Running: python quantize_and_upload.py --model_id $model_id --quant $quant"
-  python quantize_and_upload.py --model_id "$model_id" --quant "$quant"
+  echo "Running: python quantize_and_upload.py --model_id $model_id --quant $quant $push_to_hub"
+  python quantize_and_upload.py --model_id "$model_id" --quant "$quant" $push_to_hub
 done
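Putting the two files together, release.sh simply forwards the optional flag to the Python script. A usage sketch mirroring the Test Plan above ($LOCAL_MODEL_PATH is a placeholder for a local checkpoint directory):

  cd .github/scripts/torchao_model_releases/
  # Quantize a local checkpoint with the default quants (FP8, INT4, INT8-INT4), saving locally
  ./release.sh --model_id $LOCAL_MODEL_PATH
  # Quantize one variant and push the results to the Hugging Face Hub
  ./release.sh --model_id $LOCAL_MODEL_PATH --quants FP8 --push_to_hub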
