|
| 1 | +import json |
| 2 | +import re |
| 3 | +import types |
| 4 | + |
| 5 | +from google.cloud import pubsub_v1 |
| 6 | +from google.api_core.exceptions import ClientError |
| 7 | +from google.pubsub_v1.types import PullRequest |
| 8 | +from google.pubsub_v1.types import AcknowledgeRequest |
| 9 | +from .secrets import google_credentials |
| 10 | + |
| 11 | +from .lib import toiter, sip, jsonify |
| 12 | + |
| 13 | + |
| 14 | +import tenacity |
| 15 | + |
| 16 | +PUBSUB_BATCH_SIZE = 10 # send_message_batch's max batch size is 10 |
| 17 | + |
| 18 | + |
| 19 | +class ClientSideError(Exception): |
| 20 | + pass |
| 21 | + |
| 22 | + |
| 23 | +retry = tenacity.retry( |
| 24 | + reraise=True, |
| 25 | + stop=tenacity.stop_after_attempt(4), |
| 26 | + wait=tenacity.wait_random_exponential(0.5, 60.0), |
| 27 | + retry=tenacity.retry_if_not_exception_type(ClientSideError), |
| 28 | +) |
| 29 | + |
| 30 | + |
| 31 | +class PubSubTaskQueueAPI(object): |
| 32 | + def __init__(self, qurl, **kwargs): |
| 33 | + """ |
| 34 | + qurl: a topic or subscription location |
| 35 | + conforms to this format projects/{project_id}/topics/{topic_id}/subscriptions/{subscription_id} |
| 36 | + kwargs: Keywords for the underlying boto3.client constructor, other than `service_name`, |
| 37 | + `region_name`, `aws_secret_access_key`, or `aws_access_key_id`. |
| 38 | + """ |
| 39 | + pattern = r"^projects/(?P<project_id>[\w\d-]+)/topics/(?P<topic_id>[\w\d-]+)/subscriptions/(?P<subscription_id>[\w\d-]+)$" |
| 40 | + matches = re.match(pattern, qurl) |
| 41 | + if matches is None: |
| 42 | + raise ValueError( |
| 43 | + "qurl does not conform to the required format (projects/{project_id}/topics/{topic_id}/subscriptions/{subscription_id})" |
| 44 | + ) |
| 45 | + |
| 46 | + matches = re.search(r"projects/([\w\d-]+)/", qurl) |
| 47 | + self.project_id = matches.group(1) |
| 48 | + |
| 49 | + matches = re.search(r"topics/([\w\d-]+)", qurl) |
| 50 | + self.topic_id = matches.group(1) |
| 51 | + |
| 52 | + matches = re.search(r"subscriptions/([\w\d-]+)", qurl) |
| 53 | + self.subscription_id = matches.group(1) |
| 54 | + |
| 55 | + project_name, credentials = google_credentials() |
| 56 | + |
| 57 | + self.subscriber = pubsub_v1.SubscriberClient(credentials=credentials) |
| 58 | + self.publisher = pubsub_v1.PublisherClient(credentials=credentials) |
| 59 | + self._topic_path = self.publisher.topic_path(self.project_id, self.topic_id) |
| 60 | + self._subscription_path = self.subscriber.subscription_path( |
| 61 | + self.project_id, self.subscription_id |
| 62 | + ) |
| 63 | + |
| 64 | + self.batch_size = PUBSUB_BATCH_SIZE |
| 65 | + |
| 66 | + @property |
| 67 | + def enqueued(self): |
| 68 | + raise float("Nan") |
| 69 | + |
| 70 | + @property |
| 71 | + def inserted(self): |
| 72 | + return float("NaN") |
| 73 | + |
| 74 | + @property |
| 75 | + def completed(self): |
| 76 | + return float("NaN") |
| 77 | + |
| 78 | + @property |
| 79 | + def leased(self): |
| 80 | + return float("NaN") |
| 81 | + |
| 82 | + def is_empty(self): |
| 83 | + return self.enqueued == 0 |
| 84 | + |
| 85 | + @retry |
| 86 | + def insert(self, tasks, delay_seconds=0): |
| 87 | + tasks = toiter(tasks) |
| 88 | + |
| 89 | + def publish_batch(batch): |
| 90 | + if not batch: |
| 91 | + return 0 |
| 92 | + |
| 93 | + futures = [] |
| 94 | + for task in batch: |
| 95 | + data = jsonify(task).encode("utf-8") |
| 96 | + future = self.publisher.publish(self._topic_path, data) |
| 97 | + futures.append(future) |
| 98 | + |
| 99 | + # Wait for all messages to be published |
| 100 | + for future in futures: |
| 101 | + try: |
| 102 | + # Blocks until the message is published |
| 103 | + future.result() |
| 104 | + except Exception as e: |
| 105 | + raise ClientError(e) |
| 106 | + |
| 107 | + return len(futures) |
| 108 | + |
| 109 | + total = 0 |
| 110 | + |
| 111 | + # send_message_batch's max batch size is 10 |
| 112 | + for batch in sip(tasks, self.batch_size): |
| 113 | + if len(batch) == 0: |
| 114 | + break |
| 115 | + total += publish_batch(batch) |
| 116 | + |
| 117 | + return total |
| 118 | + |
| 119 | + def add_insert_count(self, ct): |
| 120 | + pass |
| 121 | + |
| 122 | + def rezero(self): |
| 123 | + pass |
| 124 | + |
| 125 | + @retry |
| 126 | + def renew_lease(self, task, seconds): |
| 127 | + self.subscriber.modify_ack_deadline( |
| 128 | + self._subscription_path, |
| 129 | + [task.id], |
| 130 | + seconds, |
| 131 | + ) |
| 132 | + |
| 133 | + def cancel_lease(self, task): |
| 134 | + self.subscriber.acknowledge(self._subscription_path, [task.id]) |
| 135 | + |
| 136 | + def release_all(self): |
| 137 | + raise NotImplementedError() |
| 138 | + |
| 139 | + def lease(self, seconds, num_tasks=1, wait_sec=20): |
| 140 | + # Pull messages from the subscription |
| 141 | + request = PullRequest( |
| 142 | + subscription=self._subscription_path, max_messages=num_tasks |
| 143 | + ) |
| 144 | + response = self.subscriber.pull(request) |
| 145 | + |
| 146 | + tasks = [] |
| 147 | + for received_message in response.received_messages: |
| 148 | + # Load the message data as JSON |
| 149 | + task = json.loads(received_message.message.data.decode("utf-8")) |
| 150 | + # Store the acknowledgement ID in the task |
| 151 | + task["id"] = received_message.ack_id |
| 152 | + tasks.append(task) |
| 153 | + |
| 154 | + return tasks |
| 155 | + |
| 156 | + def delete(self, task): |
| 157 | + if isinstance(task, str): |
| 158 | + ack_id = task |
| 159 | + else: |
| 160 | + try: |
| 161 | + ack_id = task._id |
| 162 | + except AttributeError: |
| 163 | + ack_id = task["id"] |
| 164 | + request = AcknowledgeRequest( |
| 165 | + subscription=self._subscription_path, ack_ids=[ack_id] |
| 166 | + ) |
| 167 | + self.subscriber.acknowledge(request=request) |
| 168 | + return 1 |
| 169 | + |
| 170 | + def tally(self): |
| 171 | + pass |
| 172 | + |
| 173 | + def purge(self, native=False): |
| 174 | + while True: |
| 175 | + # Pull messages from the subscription |
| 176 | + response = self.subscriber.pull( |
| 177 | + self._subscription_path, max_messages=self.batch_size |
| 178 | + ) |
| 179 | + |
| 180 | + if not response.received_messages: |
| 181 | + # No more messages, break the loop |
| 182 | + break |
| 183 | + |
| 184 | + # Acknowledge all received messages |
| 185 | + ack_ids = [msg.ack_id for msg in response.received_messages] |
| 186 | + request = AcknowledgeRequest( |
| 187 | + subscription=self._subscription_path, ack_ids=ack_ids |
| 188 | + ) |
| 189 | + self.subscriber.acknowledge(request=request) |
| 190 | + |
| 191 | + def __iter__(self): |
| 192 | + return iter(self.lease(num_tasks=10, seconds=0)) |
0 commit comments