fixed bug when dalle thinks prompt is nsfw

This commit is contained in:
phixxy 2024-02-23 00:43:05 -08:00
parent 0bec58412d
commit 220f396668

View file

@ -200,7 +200,7 @@ class ChatGPT(commands.Cog):
else:
await ctx.send("Sorry you must be a premium member to use this command. (!donate)")
async def dalle_api_call(self, prompt, model="dall-e-2", quality="standard", size="1024x1024"):
async def dalle_api_call(self, prompt: str, model: str="dall-e-2", quality: str="standard", size: str="1024x1024") -> tuple:
data = {
"model": model,
"prompt": prompt,
@ -209,35 +209,35 @@ class ChatGPT(commands.Cog):
"n":1,
}
url = "https://api.openai.com/v1/images/generations"
try:
async with self.http_session.post(url, json=data, headers=self.headers) as resp:
response_data = await resp.json()
async with self.http_session.post(url, json=data, headers=self.headers) as resp:
response_data = await resp.json()
if resp.status == 200:
response = response_data['data'][0]['url']
return response
except Exception as error:
self.logger.exception("Error occurred in dalle")
return "Error occurred in dalle"
else:
response = response_data["error"]["message"]
self.logger.info(f"Error occurred in dalle: {resp.status} | {response}")
return (resp.status, response)
async def download_image(self, url, destination):
if url == "Error occurred in dalle":
return
async def download_image(self, url, destination) -> int:
async with self.http_session.get(url) as resp:
if resp.status == 200:
f = await aiofiles.open(destination, mode='wb')
await f.write(await resp.read())
await f.close()
return destination
return resp.status
async def generate_dalle_image(self, ctx, model, quality="standard", size="1024x1024"):
async def generate_dalle_image(self, ctx, model, quality="standard", size="1024x1024") -> None:
if ctx.author.get_role(self.premium_role):
prompt = ctx.message.content.split(" ", maxsplit=1)[1]
await ctx.send(f"Please be patient this may take some time! Generating: {prompt}.")
image_url = await self.dalle_api_call(prompt, model=model, quality=quality, size=size)
resp_status, resp = await self.dalle_api_call(prompt, model=model, quality=quality, size=size)
if resp_status != 200:
await ctx.send(f"Error generating image: {resp_status}: {resp}")
return
my_filename = str(time.time_ns()) + ".png"
image_filepath = f"{self.data_dir}dalle/{my_filename}"
await self.download_image(image_url, image_filepath)
await self.download_image(resp, image_filepath)
with open(image_filepath, "rb") as fh:
f = discord.File(fh, filename=image_filepath)
prompt = prompt.replace('\n',' ')