Skip to content

Commit 156b1c0

Browse files
author
Sean Heelan
committed
Fix bug in setting of openai parameters for OpenAI provider
1 parent 1565c80 commit 156b1c0

File tree

4 files changed

+30
-11
lines changed

4 files changed

+30
-11
lines changed
File renamed without changes.

.env.openai.example

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
GAI_API_TYPE="open_ai"
2+
GAI_API_KEY=""

perf-copilot.py

+27-10
Original file line numberDiff line numberDiff line change
@@ -47,18 +47,35 @@
4747
from dotenv import load_dotenv
4848
load_dotenv()
4949

50-
openai.api_key = os.environ["GAI_API_KEY"]
51-
openai.api_type = os.environ["GAI_API_TYPE"]
52-
openai.api_base = os.environ["GAI_API_BASE"]
53-
openai.api_version = os.environ["GAI_API_VERSION"]
54-
55-
if not openai.api_key:
56-
sys.stderr.write("You must set the GAI API key\n")
50+
api_type = api_key = api_base = api_version = None
51+
52+
try:
53+
api_type = os.environ["GAI_API_TYPE"]
54+
api_key = os.environ["GAI_API_KEY"]
55+
api_base = os.environ["GAI_API_BASE"]
56+
api_version = os.environ["GAI_API_VERSION"]
57+
except KeyError:
58+
pass
59+
60+
if not api_key or not api_type:
61+
sys.stderr.write("You must set the GAI API type and key\n")
5762
sys.exit(1)
5863

59-
if openai.api_type == "azure" and not (openai.api_base and openai.api_version):
60-
sys.stderr.write("Azure requires the API base and version to be set")
61-
sys.exit(1)
64+
openai.api_key = api_key
65+
openai.api_type = api_type
66+
67+
if api_type == "azure":
68+
if not (api_base and api_version):
69+
sys.stderr.write("Azure requires the API base and version to be set")
70+
sys.exit(1)
71+
openai.api_base = api_base
72+
openai.api_version = api_version
73+
74+
if api_type == "open_ai":
75+
if api_base or api_version:
76+
sys.stderr.write("You must not to set the GAI_API_BASE or GAI_API_VERSION for the open_ai GAI_API_TYPE")
77+
sys.exit(1)
78+
6279

6380
ascii_name = """
6481
__ _ _ _

perfcopilot/llm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def get_chat_completion_args(messages, stream=False):
122122

123123
if openai.api_type == "azure":
124124
kwargs["deployment_id"] = get_model()
125-
elif openai.api_type == "openai":
125+
elif openai.api_type == "open_ai":
126126
kwargs["model"] = get_model()
127127
else:
128128
logging.error(f"Unknown API type: {openai.api_type}")

0 commit comments

Comments
 (0)