Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: test change #4847

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions google/cloud/aiplatform/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,7 @@ def create_client(
client: Instantiated Vertex AI Service client with optional overrides
"""
gapic_version = __version__
print("IN BRANCH CL PRINT -- PARKER")

if appended_gapic_version:
gapic_version = f"{gapic_version}+{appended_gapic_version}"
Expand Down Expand Up @@ -586,6 +587,7 @@ def create_client(
gapic_version=gapic_version,
user_agent=user_agent,
)
print("Branch CL client info: " + str(client_info))

kwargs = {
"credentials": credentials or self.credentials,
Expand All @@ -598,6 +600,7 @@ def create_client(
),
"client_info": client_info,
}
print("Branch CL kwargs: " + str(kwargs))

# Do not pass "grpc", rely on gapic defaults unless "rest" is specified
if self._api_transport == "rest" and "Async" in client_class.__name__:
Expand All @@ -622,7 +625,9 @@ def create_client(
client = client_class(**kwargs)
# We only wrap the client if the request_metadata is set at the creation time.
if self._request_metadata:
print("Branch CL wrapping client because request metadata is set")
client = _ClientWrapperThatAddsDefaultMetadata(client)
print("Branch CL returning client: " + str(client))
return client

def _get_default_project_and_location(self) -> Tuple[str, str]:
Expand Down
22 changes: 22 additions & 0 deletions vertexai/prompts/_prompt_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,17 +493,28 @@ def _create_prompt_version_resource(

def _get_prompt_resource(prompt: Prompt, prompt_id: str) -> gca_dataset.Dataset:
"""Helper function to get a prompt resource from a prompt id."""
print("Branch CL _get_prompt_resource")
project = aiplatform_initializer.global_config.project
location = aiplatform_initializer.global_config.location
name = f"projects/{project}/locations/{location}/datasets/{prompt_id}"
dataset = prompt._dataset_client.get_dataset(name=name)

print("Setting ._dataset_client.client_info.gapic_version")
prompt._dataset_client.client_info.gapic_version = prompt._dataset_client.client_info.gapic_version + "+prompt_management"
print(str(prompt._dataset_client.client_info.gapic_version))

print("Setting ._dataset_client.appended_gapic_version")
prompt._dataset_client.appended_gapic_version = prompt._dataset_client.appended_gapic_version + "+prompt_management"
print(str(prompt._dataset_client.appended_gapic_version))

return dataset


def _get_prompt_resource_from_version(
prompt: Prompt, prompt_id: str, version_id: str
) -> gca_dataset.Dataset:
"""Helper function to get a prompt resource from a prompt version id."""
print("Branch CL _get_prompt_resource_from_version")
project = aiplatform_initializer.global_config.project
location = aiplatform_initializer.global_config.location
name = f"projects/{project}/locations/{location}/datasets/{prompt_id}/datasetVersions/{version_id}"
Expand All @@ -516,6 +527,14 @@ def _get_prompt_resource_from_version(
name = f"projects/{project}/locations/{location}/datasets/{prompt_id}"
dataset = prompt._dataset_client.get_dataset(name=name)

print("Setting ._dataset_client.client_info.gapic_version")
prompt._dataset_client.client_info.gapic_version = prompt._dataset_client.client_info.gapic_version + "+prompt_management"
print(str(prompt._dataset_client.client_info.gapic_version))

print("Setting ._dataset_client.appended_gapic_version")
prompt._dataset_client.appended_gapic_version = prompt._dataset_client.appended_gapic_version + "+prompt_management"
print(str(prompt._dataset_client.appended_gapic_version))

# Step 3: Convert to DatasetVersion object to Dataset object
dataset = gca_dataset.Dataset(
name=name,
Expand Down Expand Up @@ -573,19 +592,22 @@ def get(prompt_id: str, version_id: Optional[str] = None) -> Prompt:
"""
prompt = Prompt()
if version_id:
print("Branch CL get prompt resource from version")
dataset = _get_prompt_resource_from_version(
prompt=prompt,
prompt_id=prompt_id,
version_id=version_id,
)
else:
print("Branch CL get prompt resource")
dataset = _get_prompt_resource(prompt=prompt, prompt_id=prompt_id)

# Remove etag to avoid error for repeated dataset updates
dataset.etag = None

prompt._dataset = dataset
prompt._version_id = version_id
prompt._used_prompt_management = True

dataset_dict = _proto_to_dict(dataset)

Expand Down
7 changes: 7 additions & 0 deletions vertexai/prompts/_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def __init__(
self._prompt_name = None
self._version_id = None
self._version_name = None
self._used_prompt_management = None

self.prompt_data = prompt_data
self.variables = variables if variables else [{}]
Expand Down Expand Up @@ -610,6 +611,12 @@ def generate_content(
model = GenerativeModel(
model_name=model_name, system_instruction=system_instruction
)

if self._used_prompt_management:
# Want to update `appended_gapic_version` field here with the
# boolean value...
print("Branch CL generate_content AFTER _used_prompt_management")

return model.generate_content(
contents=contents,
generation_config=generation_config,
Expand Down
Loading