-
Notifications
You must be signed in to change notification settings - Fork 1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
genenric:sycl: Group Norm Forward #2733
Conversation
make test |
I'm holding off on rebasing onto the latest main(22/02/2025) to resolve the conflict as it is leading to unrelated build failures -
(The above build is on the latest main (22/02/2025) with gpu vendor as generic) |
My bad, somehow it slipped from me during the development, and Cuda-based build wasn't triggered by default for PR validation, I'll follow up on this. |
Thanks a lot @dzarukin , and no issues at all :) |
21ced41
to
0b960b9
Compare
@AD2605 we should also update the README file with group norm. |
Thanks for pointing that out, I forgot about. Will update the PR |
2a25d50
to
9b57921
Compare
9b57921
to
99cc959
Compare
make test |
Pre commit failures are unrelated to the changes in this PR |
* Copyright 2023-2024 Intel Corporation | ||
* Copyright 2024-2025 Codeplay Software Limited |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For new files, please just use the current year. Please also fix the other new files in the PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
if (conf_.src_scaling) { | ||
// Only one scaling factor per tensor is allowed, | ||
// as per the spec. Scaling factor will also always be f32 as per spec. | ||
normalized_value *= load_float_value( | ||
data_type::f32, src_scale.get_pointer(), 0); | ||
} | ||
float prev_value = normalized_value; | ||
normalized_value = conf_.post_ops.apply( | ||
normalized_value, prev_value, po_args_, logical_index); | ||
if (conf_.dst_scaling) { | ||
// Only one scaling factor per tensor is allowed, | ||
// as per the spec. Scaling factor will also always be f32 as per spec. | ||
normalized_value *= (1.0f | ||
/ load_float_value(data_type::f32, | ||
dst_scale.get_pointer(), 0)); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The common check seems to only cover scales mask, not scales datatype.
Could you at least add a proper check in the init()
for the scales datatype?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added the check
| primitive_attr_t::skip_mask_t::post_ops; | ||
VDISPATCH_GNORM( | ||
attr()->has_default_values(attr_mask), VERBOSE_UNSUPPORTED_ATTR); | ||
VDISPATCH_GNORM(attr_scales_ok(), VERBOSE_UNSUPPORTED_SCALES_CFG); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general, it is preferable for each implementation to check for attributes they skipped default checking, to avoid issues when we extend functionality.
Could you add the extra restrictions from implementation here (for datatype and mask)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added the check for both
99cc959
to
2e850eb
Compare
2e850eb
to
0e74961
Compare
@mgouicem please let me know in case you are happy with the current state of the PR and the resolutions to your comments. just wanted to check-in with you before the merge as the PR has received the required approvals |
0e74961
to
0b5863f
Compare
Description
Adds the implementation of group normalization for the SYCL backend.
General
make test
andmake test_benchdnn_*
) pass locally for each commit?