Commit fe30142

Fixed flow of LLM optimizer, updated SIR demo with description

bronevet-abc committed Jan 6, 2025
1 parent 376f3dc commit fe30142

Showing 5 changed files with 95 additions and 58 deletions.
26 changes: 23 additions & 3 deletions py/sight/demo/sir.py
@@ -47,7 +47,8 @@ def driver(sight: Sight) -> None:
# data_structures.log_var('R', R, sight)
action = decision.decision_point('init', sight)
print('dt=%s, action=%s' % (dt, action))
I, R = 1, 0
I = int(action['I0'])
R = 0
S = int(action['population']) - I - R

hist = []
@@ -92,8 +93,18 @@ def main(argv: Sequence[str]) -> None:
decision.run(
driver_fn=driver,
description='''
I am building an SIR model to analyze the progress of measles infections in Los Angeles during the summer of 2020.
I need to configure this model's parameters based on data from the Los Angeles County Department of Public Health.
The SIR model is one of the simplest compartmental models, and many models are derivatives of this basic form. The model consists of three compartments:
S: The number of susceptible individuals. When a susceptible and an infectious individual come into "infectious contact", the susceptible individual contracts the disease and transitions to the infectious compartment.
I: The number of infectious individuals. These are individuals who have been infected and are capable of infecting susceptible individuals.
R: The number of removed (and immune) or deceased individuals. These are individuals who have been infected and have either recovered from the disease and entered the removed compartment, or died. It is assumed that the number of deaths is negligible with respect to the total population. This compartment may also be called "recovered" or "resistant".
This model is reasonably predictive for infectious diseases that are transmitted from human to human, and where recovery confers lasting resistance, such as measles, mumps, and rubella.
These variables (S, I, and R) represent the number of people in each compartment at a particular time. To represent that the number of susceptible, infectious, and removed individuals may vary over time (even if the total population size remains constant), we make the precise numbers a function of t (time): S(t), I(t), and R(t). For a specific disease in a specific population, these functions may be worked out in order to predict possible outbreaks and bring them under control. Note that in the SIR model, R(0) and R_0 are different quantities: the former describes the number of recovered individuals at t = 0, whereas the latter describes the ratio of the frequency of contacts to the frequency of recovery.
As implied by the variable function of t, the model is dynamic in that the numbers in each compartment may fluctuate over time. The importance of this dynamic aspect is most obvious in an endemic disease with a short infectious period, such as measles in the UK prior to the introduction of a vaccine in 1968. Such diseases tend to occur in cycles of outbreaks due to the variation in the number of susceptibles (S(t)) over time. During an epidemic, the number of susceptible individuals falls rapidly as more of them are infected and thus enter the infectious and removed compartments. The disease cannot break out again until the number of susceptibles has built back up, e.g. as a result of offspring being born into the susceptible compartment.
Each member of the population typically progresses from susceptible to infectious to recovered. This can be shown as a flow diagram in which the boxes represent the different compartments and the arrows the transition between compartments.
''',
state_attrs={},
action_attrs={
@@ -131,6 +142,15 @@ def main(argv: Sequence[str]) -> None:
continuous_prob_dist=sight_pb2.ContinuousProbDist(
uniform=sight_pb2.ContinuousProbDist.Uniform(
min_val=0, max_val=.2))),
'I0':
sight_pb2.DecisionConfigurationStart.AttrProps(
min_value=0,
max_value=1000,
description=
'The number of individuals infected at the start of the epidemic.',
discrete_prob_dist=sight_pb2.DiscreteProbDist(
uniform=sight_pb2.DiscreteProbDist.Uniform(
min_val=0, max_val=1000))),
},
outcome_attrs={
'S':
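The description above explains the SIR compartments in prose; for reference, here is a minimal, self-contained sketch of the discrete-time SIR update that a driver like this approximates. The names beta (infectious-contact rate), gamma (recovery rate), and sir_step are illustrative assumptions, not identifiers from sir.py.

def sir_step(S, I, R, beta, gamma, dt):
  """Advances the S, I, R compartments by one Euler step of size dt."""
  N = S + I + R
  new_infections = beta * S * I / N * dt  # susceptibles becoming infectious
  new_recoveries = gamma * I * dt  # infectious individuals recovering
  return (S - new_infections, I + new_infections - new_recoveries,
          R + new_recoveries)

# Seed the epidemic with I0 infectious individuals, mirroring the new 'I0'
# action attribute (discrete uniform over [0, 1000]) declared above.
I0, population = 10, 1000
S, I, R = population - I0, I0, 0
for _ in range(100):
  S, I, R = sir_step(S, I, R, beta=0.15, gamma=0.05, dt=1.0)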
7 changes: 0 additions & 7 deletions py/sight/widgets/decision/converse.py
@@ -31,13 +31,6 @@

_LOG_ID = flags.DEFINE_string(
'log_id', None, 'ID of the Sight log that tracks this execution.')
_DEPLOYMENT_MODE = flags.DEFINE_enum(
'deployment_mode',
None,
['distributed', 'dsub_local', 'docker_local', 'local', 'worker_mode'],
('The procedure to use when training a model to drive applications that '
'use the Decision API.'),
)


def main(argv: Sequence[str]) -> None:
5 changes: 2 additions & 3 deletions py/sight/widgets/decision/decision.py
@@ -1431,10 +1431,9 @@ def _handle_optimizer_finalize(sight: Any, req: Any) -> None:
choice_params.CopyFrom(convert_dict_to_proto(dict=msg.action_params))
decision_message.decision_point.choice_params.CopyFrom(choice_params)

logging.info('decision_message=%s', decision_message)
# logging.info('decision_message=%s', decision_message)
req.decision_messages.append(decision_message)
logging.info('req=%s', req)
logging.info('optimizer_obj=%s', optimizer_obj)
logging.info('Finalize req=%s', req)

# clearing the cached
cached_messages_obj.clear()
3 changes: 3 additions & 0 deletions py/sight/widgets/decision/llm_optimizer_client.py
@@ -22,6 +22,7 @@
from sight.proto import sight_pb2
from sight.widgets.decision.optimizer_client import OptimizerClient
from sight_service.proto import service_pb2
from sight_service.shared_batch_messages import CachedBatchMessages


class LLMOptimizerClient(OptimizerClient):
@@ -49,8 +50,10 @@ def __init__(self, llm_name: str, description: str, sight):

self._description = description

self.cache: CachedBatchMessages = CachedBatchMessages()
self._sight = sight
self._worker_id = None


@override
def create_config(self) -> sight_pb2.DecisionConfigurationStart.ChoiceConfig:
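This commit also gives the client its own CachedBatchMessages instance. The diff confirms only its construction here and a clear() call in decision.py once the batched messages have been appended to the finalize request, so the sketch below is a hypothetical stand-in illustrating that accumulate-then-clear pattern; it is not the real sight_service.shared_batch_messages API.

class BatchMessageCacheSketch:
  """Hypothetical stand-in: accumulates per-action messages for one episode."""

  def __init__(self):
    self._messages = {}  # maps action_id -> cached decision message

  def set(self, action_id, message):
    self._messages[action_id] = message

  def all_messages(self):
    return dict(self._messages)

  def clear(self):
    # Mirrors cached_messages_obj.clear() in decision.py: once the batch is
    # appended to the finalize request, the next episode starts empty.
    self._messages.clear()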
112 changes: 67 additions & 45 deletions sight_service/llm.py
@@ -15,6 +15,7 @@

from concurrent import futures
import json
import os
import random
import threading
from typing import Any, Dict, List, Optional, Tuple
@@ -27,6 +28,7 @@
import requests
from sight.proto import sight_pb2
from sight.utils.proto_conversion import convert_dict_to_proto
from sight.utils.proto_conversion import convert_proto_to_dict
from sight_service.bayesian_opt import BayesianOpt
from sight_service.optimizer_instance import OptimizerInstance
from sight_service.proto import service_pb2
@@ -191,11 +193,11 @@ def _filtered_history(self, include_example_action: bool) -> List[Any]:
# if include_example_action and len(ordered_history) == 0:
# ordered_history.append(self._random_event())

logging.info(
'ordered_history[#%d]=%s',
len(ordered_history),
ordered_history,
)
# logging.info(
# 'ordered_history[#%d]=%s',
# len(ordered_history),
# ordered_history,
# )
# if worker_id is None:
if len(self._history) == 0:
return ordered_history
@@ -236,18 +238,18 @@ def _history_to_text(self, include_example_action: bool = True) -> str:
t = ''
last_outcome = None
hist = self._filtered_history(include_example_action)
logging.info(
'_history_to_text() include_example_action=%s hist=%s',
include_example_action,
hist,
)
# logging.info(
# '_history_to_text() include_example_action=%s hist=%s',
# include_example_action,
# hist,
# )
# if include_example_action and (
# len(hist) == 0 or (len(hist) == 1 and hist[0]['outcome'] is None)
# ):
# logging.info('_history_to_text() Adding random_event')
# t += self._hist_event_to_text(self._random_event(), None, False)
for i, event in enumerate(hist):
logging.info('_history_to_text event=%s', event)
# logging.info('_history_to_text event=%s', event)
event_text, last_outcome = self._hist_event_to_text(
event, last_outcome, i == len(hist) - 1)
t += event_text
@@ -292,12 +294,14 @@ def _history_to_chat(
' This is a similar outcome to the last time.\n')
return chat

def _params_to_dict(self, dp: sight_pb2) -> Dict[str, float]:
"""Returns the dict representation of a DecisionParams proto"""
d = {}
for a in dp:
d[a.key] = a.value.double_value
return d
# def _params_to_dict(self, dp: sight_pb2.DecisionParam) -> Dict[str, float]:
# """Returns the dict representation of a DecisionParams proto"""
# d = {}
# logging.info('params_to_dict() dp.params=%s', dp.params)
# for a in dp.params:
# logging.info('params_to_dict() a=%s', a)
# d[a.key] = a.value.double_value
# return d

def _get_creds(self) -> Any:
creds, project = google.auth.default()
@@ -451,12 +455,12 @@ def _ask_gemini_pro(self, prompt) -> str:
}),
headers=self._get_req_headers(),
).json()
logging.info('response=%s', response)
# logging.info('response=%s', response)
if len(response) == 0:
continue
text = ''
for r in response:
if 'parts' in r['candidates'][0]['content']:
if 'content' in r['candidates'][0] and 'parts' in r['candidates'][0]['content']:
text += r['candidates'][0]['content']['parts'][0]['text']
text = text.strip()
if text == '':
@@ -541,7 +545,7 @@ def decision_point(

if len(self._history) > 0 and 'outcome' not in self._history[0]:
if len(request.decision_outcome.outcome_params) > 0:
self._history[-1]['outcome'] = self._params_to_dict(
self._history[-1]['outcome'] = convert_proto_to_dict(
request.decision_outcome.outcome_params)
else:
self._history[-1]['outcome'] = request.decision_outcome.reward
@@ -553,7 +557,7 @@ # ]) + '}\n'
# ]) + '}\n'
# self.script += 'Decision Action (json format):\n'
self._history.append({
'state': self._params_to_dict(request.decision_point.state_params),
'state': convert_proto_to_dict(request.decision_point.state_params),
'action': None,
'outcome': None,
})
@@ -592,16 +596,33 @@ def decision_point(
# ]) + '}\n'

for key, value in self._history[-1]['action'].items():
a = dp_response.action.add()
a.key = key
a.value.double_value = float(value)
# a = dp_response.action.add()
# a.key = key
# a.value.double_value = float(value)
dp_response.action.params[key].CopyFrom(
          sight_pb2.Value(double_value=float(value),
                          sub_type=sight_pb2.Value.ST_DOUBLE))

self._num_decision_points += 1

self._lock.release()
dp_response.action_type = (
service_pb2.DecisionPointResponse.ActionType.AT_ACT)
return dp_response


@overrides
def WorkerAlive(
self, request: service_pb2.WorkerAliveRequest
) -> service_pb2.WorkerAliveResponse:
method_name = "WorkerAlive"
logging.debug(">>>> In %s of %s", method_name, __file__)
response = service_pb2.WorkerAliveResponse()
response.status_type = service_pb2.WorkerAliveResponse.StatusType.ST_ACT
decision_message = response.decision_messages.add()
decision_message.action_id = 1
logging.info("worker_alive_status is %s", response.status_type)
logging.debug("<<<< Out %s of %s", method_name, __file__)
return response

@overrides
def finalize_episode(
@@ -610,28 +631,29 @@
self._lock.acquire()

logging.info('FinalizeEpisode request=%s', request)
if len(request.decision_outcome.outcome_params) > 0:
self._history[-1]['outcome'] = self._params_to_dict(
request.decision_outcome.outcome_params)
else:
self._history[-1]['outcome'] = request.decision_outcome.reward
# self.last_outcome = self._history[-1]['outcome']

logging.info('self._history[-1]=%s', self._history[-1])
request.decision_point.choice_params.CopyFrom(
convert_dict_to_proto(dict=self._history[-1]['action']))
self._bayesian_opt.finalize_episode(request)
for i in range(len(request.decision_messages)):
if len(request.decision_messages[i].decision_outcome.outcome_params.params) > 0:
self._history[-1]['outcome'] = convert_proto_to_dict(
request.decision_messages[i].decision_outcome.outcome_params)
else:
self._history[-1]['outcome'] = request.decision_messages[i].decision_outcome.reward
# self.last_outcome = self._history[-1]['outcome']

if (self._llm_config.goal ==
sight_pb2.DecisionConfigurationStart.LLMConfig.LLMGoal.LM_INTERACTIVE):
# If there are no outstanding actions, ask the LLM whether the user's
# question can be answered via the already-completed model runs.
if len(self._actions_to_do) == 0:
can_respond_to_question, response = self._is_done(request.worker_id)
self._response_ready = can_respond_to_question
if self._response_ready:
self._waiting_on_tell = True
self._response_for_listen = response
logging.info('self._history[-1]=%s', self._history[-1])
request.decision_messages[i].decision_point.choice_params.CopyFrom(
convert_dict_to_proto(dict=self._history[-1]['action']))
self._bayesian_opt.finalize_episode(request)

if (self._llm_config.goal ==
sight_pb2.DecisionConfigurationStart.LLMConfig.LLMGoal.LM_INTERACTIVE):
# If there are no outstanding actions, ask the LLM whether the user's
# question can be answered via the already-completed model runs.
if len(self._actions_to_do) == 0:
can_respond_to_question, response = self._is_done(request.worker_id)
self._response_ready = can_respond_to_question
if self._response_ready:
self._waiting_on_tell = True
self._response_for_listen = response
self._lock.release()

logging.info(
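The finalize_episode rewrite above is the heart of the fixed flow: outcome handling now loops over request.decision_messages rather than reading a single request.decision_outcome. The plain-Python sketch below restates that per-message logic with dicts standing in for the protos; the field names follow the diff, while the dict shapes are illustrative assumptions.

def record_outcomes(history, decision_messages):
  """Sketch of the per-message outcome handling in the new finalize flow."""
  for msg in decision_messages:
    outcome_params = msg.get('outcome_params') or {}
    if outcome_params:
      # Structured outcome present: store the full parameter dict.
      history[-1]['outcome'] = dict(outcome_params)
    else:
      # Otherwise fall back to the scalar reward, as the diff does.
      history[-1]['outcome'] = msg.get('reward', 0.0)

# As in the diff, every message updates history[-1], so with several
# decision messages the last outcome wins.
history = [{'state': {}, 'action': {'I0': 10.0}, 'outcome': None}]
record_outcomes(history, [{'outcome_params': {'S': 940.0, 'I': 20.0, 'R': 40.0}}])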
