added type hinting and docstrings
This commit is contained in:
parent
40b6aaab38
commit
b023724e09
1 changed files with 197 additions and 60 deletions
|
|
@ -20,7 +20,7 @@ class StableDiffusion(commands.Cog):
|
||||||
self.default_neg_prompt = "easynegative, badhandv4, verybadimagenegative_v1.3"
|
self.default_neg_prompt = "easynegative, badhandv4, verybadimagenegative_v1.3"
|
||||||
self.folder_setup()
|
self.folder_setup()
|
||||||
|
|
||||||
def folder_setup(self):
|
def folder_setup(self) -> None:
|
||||||
try:
|
try:
|
||||||
if not os.path.exists(self.working_dir):
|
if not os.path.exists(self.working_dir):
|
||||||
os.mkdir(self.working_dir)
|
os.mkdir(self.working_dir)
|
||||||
|
|
@ -31,42 +31,91 @@ class StableDiffusion(commands.Cog):
|
||||||
except:
|
except:
|
||||||
self.bot.logger.exception("StableDiffusion failed to make directories")
|
self.bot.logger.exception("StableDiffusion failed to make directories")
|
||||||
|
|
||||||
async def answer_question(self, topic, model="gpt-3.5-turbo"): # Only needed for draw command
|
"""
|
||||||
headers = {
|
answer_question asynchronously calls the OpenAI API to get a response for the given question/topic using the specified model.
|
||||||
'Content-Type': 'application/json',
|
|
||||||
'Authorization': f'Bearer {os.getenv("openai.api_key")}',
|
|
||||||
}
|
|
||||||
data = {
|
|
||||||
"model": model,
|
|
||||||
"messages": [{"role": "user", "content": topic}]
|
|
||||||
}
|
|
||||||
url = "https://api.openai.com/v1/chat/completions"
|
|
||||||
try:
|
|
||||||
async with self.bot.http_session.post(url, headers=headers, json=data) as resp:
|
|
||||||
response_data = await resp.json()
|
|
||||||
response = response_data['choices'][0]['message']['content']
|
|
||||||
return response
|
|
||||||
except Exception as error:
|
|
||||||
return "Error in answer question in stable_diffusion"
|
|
||||||
|
|
||||||
def get_kv_from_ctx(self, ctx):
|
Parameters:
|
||||||
try:
|
- topic (str): The question or topic to get a response for.
|
||||||
prompt = ctx.message.content.split(" ", maxsplit=1)[1]
|
- model (str): The OpenAI model to use. Defaults to "gpt-3.5-turbo".
|
||||||
kv_strings = list(filter(lambda x: '=' in x,prompt.split(' ')))
|
|
||||||
key_value_pairs = dict(map(lambda a: a.replace(',','').split('='),kv_strings))
|
|
||||||
return key_value_pairs
|
|
||||||
except:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_prompt_from_ctx(self, ctx):
|
Returns:
|
||||||
try:
|
- str: The response from the OpenAI API.
|
||||||
prompt = ctx.message.content.split(" ", maxsplit=1)[1]
|
|
||||||
prompt = ' '.join(list(filter(lambda x: '=' not in x,prompt.split(' '))))
|
|
||||||
return prompt
|
|
||||||
except:
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def my_open_img_file(self, path):
|
Raises:
|
||||||
|
- Exception: If an error occurs when calling the API.
|
||||||
|
"""
|
||||||
|
async def answer_question(self, topic: str, model: str="gpt-3.5-turbo") -> str: # Only needed for draw command
|
||||||
|
headers = {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
'Authorization': f'Bearer {os.getenv("openai.api_key")}',
|
||||||
|
}
|
||||||
|
data = {
|
||||||
|
"model": model,
|
||||||
|
"messages": [{"role": "user", "content": topic}]
|
||||||
|
}
|
||||||
|
url = "https://api.openai.com/v1/chat/completions"
|
||||||
|
try:
|
||||||
|
async with self.bot.http_session.post(url, headers=headers, json=data) as resp:
|
||||||
|
response_data = await resp.json()
|
||||||
|
response = response_data['choices'][0]['message']['content']
|
||||||
|
return response
|
||||||
|
except:
|
||||||
|
return "Error in answer question in stable_diffusion"
|
||||||
|
|
||||||
|
"""
|
||||||
|
Gets key-value pairs from a context message.
|
||||||
|
|
||||||
|
Parses the message content to extract key-value pairs separated by '='.
|
||||||
|
Returns a dict of key-value pairs.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
ctx (commands.Context): The context object containing the message.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A dict of key-value pairs extracted from the message.
|
||||||
|
"""
|
||||||
|
def get_kv_from_ctx(self, ctx: commands.Context) -> dict:
|
||||||
|
try:
|
||||||
|
prompt = ctx.message.content.split(" ", maxsplit=1)[1]
|
||||||
|
kv_strings = list(filter(lambda x: '=' in x,prompt.split(' ')))
|
||||||
|
key_value_pairs = dict(map(lambda a: a.replace(',','').split('='),kv_strings))
|
||||||
|
return key_value_pairs
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
|
"""
|
||||||
|
Gets prompt from context message by splitting on spaces and removing key-value pairs.
|
||||||
|
|
||||||
|
Splits the context message content on spaces, takes the second part after
|
||||||
|
the command name. Removes any key-value pairs separated by '=' from the prompt.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
ctx (commands.Context): The context object containing the message.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The prompt text extracted from the context message.
|
||||||
|
"""
|
||||||
|
def get_prompt_from_ctx(self, ctx: commands.Context) -> str:
|
||||||
|
try:
|
||||||
|
prompt = ctx.message.content.split(" ", maxsplit=1)[1]
|
||||||
|
prompt = ' '.join(list(filter(lambda x: '=' not in x,prompt.split(' '))))
|
||||||
|
return prompt
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
|
"""
|
||||||
|
Encodes an image file from the given path into a base64 string.
|
||||||
|
|
||||||
|
Opens the image file, encodes it into a base64 string, closes the image,
|
||||||
|
and returns the encoded string.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
path (str): The path to the image file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The base64 encoded image data.
|
||||||
|
"""
|
||||||
|
async def my_open_img_file(self, path: str) -> str:
|
||||||
img = Image.open(path)
|
img = Image.open(path)
|
||||||
encoded = ""
|
encoded = ""
|
||||||
with io.BytesIO() as output:
|
with io.BytesIO() as output:
|
||||||
|
|
@ -76,12 +125,27 @@ class StableDiffusion(commands.Cog):
|
||||||
img.close()
|
img.close()
|
||||||
return encoded
|
return encoded
|
||||||
|
|
||||||
async def look_at(self, ctx, look=False):
|
"""
|
||||||
|
Looks at an image attachment in the given context and returns metadata about it.
|
||||||
|
|
||||||
|
If the look parameter is True, this iterates through the attachments
|
||||||
|
in the context checking for image files. If an image is found, it is
|
||||||
|
downloaded and encoded to base64. The image is then sent to the
|
||||||
|
Stable Diffusion API to generate a caption, which is returned in the metadata.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
ctx (commands.Context): The context containing the command and attachments.
|
||||||
|
look (bool): Whether to look at images and generate metadata.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The metadata string containing any generated image captions.
|
||||||
|
"""
|
||||||
|
async def look_at(self, ctx: commands.Context, look: bool=False) -> str:
|
||||||
metadata = ""
|
metadata = ""
|
||||||
if look:
|
if look:
|
||||||
url = self.stable_diffusion_url
|
url = self.stable_diffusion_url
|
||||||
if url == "disabled":
|
if url == "disabled":
|
||||||
return
|
return "Stable Diffusion is disabled, could not look at image"
|
||||||
for attachment in ctx.attachments:
|
for attachment in ctx.attachments:
|
||||||
if attachment.url.endswith(('.jpg', '.png')):
|
if attachment.url.endswith(('.jpg', '.png')):
|
||||||
self.bot.logger.debug("image seen")
|
self.bot.logger.debug("image seen")
|
||||||
|
|
@ -110,7 +174,15 @@ class StableDiffusion(commands.Cog):
|
||||||
return "ERROR: CLIP may not be running. Could not look at image."
|
return "ERROR: CLIP may not be running. Could not look at image."
|
||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
async def generate_prompt(self):
|
"""
|
||||||
|
Generates a prompt for use with an AI art generator.
|
||||||
|
|
||||||
|
Combines randomly selected question prompts with an AI assistant's response,
|
||||||
|
then optionally removes abstract keywords and adds modifiers like "masterpiece"
|
||||||
|
to create a prompt that describes a detailed scene or character for the AI art
|
||||||
|
generator.
|
||||||
|
"""
|
||||||
|
async def generate_prompt(self) -> str:
|
||||||
choice1 = "Give me 11 keywords I can use to generate art using AI. They should all be related to one piece of art. Please only respond with the keywords and no other text. Be sure to use keywords that really describe what the art portrays. Keywords should be comma separated with no other text!"
|
choice1 = "Give me 11 keywords I can use to generate art using AI. They should all be related to one piece of art. Please only respond with the keywords and no other text. Be sure to use keywords that really describe what the art portrays. Keywords should be comma separated with no other text!"
|
||||||
choice2 = "Describe a creative scene, use only one sentence"
|
choice2 = "Describe a creative scene, use only one sentence"
|
||||||
choice3 = "Give me comma seperated keywords describing an imaginary piece of art. Only return the keywords and no other text."
|
choice3 = "Give me comma seperated keywords describing an imaginary piece of art. Only return the keywords and no other text."
|
||||||
|
|
@ -128,13 +200,12 @@ class StableDiffusion(commands.Cog):
|
||||||
prompt = prompt + ", masterpiece, studio quality"
|
prompt = prompt + ", masterpiece, studio quality"
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
@commands.command(
|
@commands.command(
|
||||||
description="Change Model",
|
description="Change Model",
|
||||||
help="Choose from a list of stable diffusion models.",
|
help="Changes the Stable Diffusion model used by the bot.",
|
||||||
brief="Change stable diffusion model"
|
brief="Change stable diffusion model"
|
||||||
)
|
)
|
||||||
async def change_model(self, ctx, model_choice='0'): # Needs to be a configurable list of models
|
async def change_model(self, ctx: commands.Context, model_choice: str='0') -> None: # Needs to be a configurable list of models
|
||||||
model_choices = {
|
model_choices = {
|
||||||
'1': ("deliberate_v2.safetensors [9aba26abdf]", "DeliberateV2"),
|
'1': ("deliberate_v2.safetensors [9aba26abdf]", "DeliberateV2"),
|
||||||
'2': ("flat2DAnimerge_v30.safetensors [5dd56bfa12]", "Flat2D"),
|
'2': ("flat2DAnimerge_v30.safetensors [5dd56bfa12]", "Flat2D"),
|
||||||
|
|
@ -170,10 +241,10 @@ class StableDiffusion(commands.Cog):
|
||||||
|
|
||||||
@commands.command(
|
@commands.command(
|
||||||
description="Lora",
|
description="Lora",
|
||||||
help="List the stable diffusion loras.",
|
help="Lists available Stable Diffusion loras and their trigger words.",
|
||||||
brief="List the stable diffusion loras"
|
brief="List the stable diffusion loras"
|
||||||
)
|
)
|
||||||
async def lora(self, ctx):
|
async def lora(self, ctx: commands.Context) -> None:
|
||||||
lora_choices = {
|
lora_choices = {
|
||||||
'0': ("Lora Name", "Trigger Words"),
|
'0': ("Lora Name", "Trigger Words"),
|
||||||
'1': ("<lora:rebecca:1>", "rebecca (cyberpunk)"),
|
'1': ("<lora:rebecca:1>", "rebecca (cyberpunk)"),
|
||||||
|
|
@ -186,7 +257,18 @@ class StableDiffusion(commands.Cog):
|
||||||
output += lora_options
|
output += lora_options
|
||||||
await ctx.send(output)
|
await ctx.send(output)
|
||||||
|
|
||||||
async def get_image_from_ctx(self, ctx):
|
"""
|
||||||
|
Gets the image URL from a Discord context.
|
||||||
|
|
||||||
|
Checks for an image URL in attachments or message content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx: Discord context
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Image URL or None
|
||||||
|
"""
|
||||||
|
async def get_image_from_ctx(self, ctx: commands.Context) -> str:
|
||||||
if ctx.message.attachments:
|
if ctx.message.attachments:
|
||||||
file_url = ctx.message.attachments[0].url
|
file_url = ctx.message.attachments[0].url
|
||||||
return file_url
|
return file_url
|
||||||
|
|
@ -197,7 +279,17 @@ class StableDiffusion(commands.Cog):
|
||||||
self.bot.logger.info("Couldn't find image.")
|
self.bot.logger.info("Couldn't find image.")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def txt2img(self, ctx, prompt):
|
"""
|
||||||
|
Sends an image generation request to the Stable Diffusion API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx: The Discord context.
|
||||||
|
prompt: The text prompt to generate the image from.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None. Sends the generated image back to the user.
|
||||||
|
"""
|
||||||
|
async def txt2img(self, ctx: commands.Context, prompt: str) -> None:
|
||||||
url = f"{self.stable_diffusion_url}/sdapi/v1/txt2img"
|
url = f"{self.stable_diffusion_url}/sdapi/v1/txt2img"
|
||||||
key_value_pairs = self.get_kv_from_ctx(ctx)
|
key_value_pairs = self.get_kv_from_ctx(ctx)
|
||||||
headers = {'Content-Type': 'application/json'}
|
headers = {'Content-Type': 'application/json'}
|
||||||
|
|
@ -226,7 +318,16 @@ class StableDiffusion(commands.Cog):
|
||||||
|
|
||||||
await self.send_generated_image(ctx, r['images'], prompt)
|
await self.send_generated_image(ctx, r['images'], prompt)
|
||||||
|
|
||||||
async def save_image(self, url):
|
"""
|
||||||
|
Saves an image from a URL to disk.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: The URL of the image to save.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The path to the saved image file.
|
||||||
|
"""
|
||||||
|
async def save_image(self, url: str) -> str:
|
||||||
async with self.bot.http_session.get(url) as response:
|
async with self.bot.http_session.get(url) as response:
|
||||||
image_name = self.working_dir + str(time.time_ns()) + ".png"
|
image_name = self.working_dir + str(time.time_ns()) + ".png"
|
||||||
with open(image_name, 'wb') as out_file:
|
with open(image_name, 'wb') as out_file:
|
||||||
|
|
@ -238,7 +339,22 @@ class StableDiffusion(commands.Cog):
|
||||||
out_file.write(chunk)
|
out_file.write(chunk)
|
||||||
return image_name
|
return image_name
|
||||||
|
|
||||||
async def img2img(self, ctx, prompt):
|
"""
|
||||||
|
Generates an image by modifying an initial image based on an optional
|
||||||
|
text prompt.
|
||||||
|
|
||||||
|
Sends a request to the Stable Diffusion API to modify the initial image
|
||||||
|
according to the given prompt. The modified image is then sent back to
|
||||||
|
the user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx: The Discord context.
|
||||||
|
prompt: The text prompt to guide image modification.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None. Sends the generated image back to the user.
|
||||||
|
"""
|
||||||
|
async def img2img(self, ctx: commands.Context, prompt: str) -> None:
|
||||||
url = f"{self.stable_diffusion_url}/sdapi/v1/img2img"
|
url = f"{self.stable_diffusion_url}/sdapi/v1/img2img"
|
||||||
file_url = await self.get_image_from_ctx(ctx)
|
file_url = await self.get_image_from_ctx(ctx)
|
||||||
image_name = await self.save_image(file_url)
|
image_name = await self.save_image(file_url)
|
||||||
|
|
@ -273,7 +389,20 @@ class StableDiffusion(commands.Cog):
|
||||||
await self.send_generated_image(ctx, r['images'], prompt)
|
await self.send_generated_image(ctx, r['images'], prompt)
|
||||||
|
|
||||||
|
|
||||||
async def send_generated_image(self, ctx, images, prompt):
|
"""
|
||||||
|
Sends a generated image file to Discord along with the prompt.
|
||||||
|
|
||||||
|
Saves the image file locally first, logs the prompt and filename,
|
||||||
|
then sends the image and prompt to Discord.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx: The Discord context.
|
||||||
|
images: List of base64 encoded image data.
|
||||||
|
prompt: The text prompt used to generate the image.
|
||||||
|
|
||||||
|
Returns: None.
|
||||||
|
"""
|
||||||
|
async def send_generated_image(self, ctx: commands.Context, images: dict, prompt: str) -> None:
|
||||||
for i in images:
|
for i in images:
|
||||||
image = Image.open(io.BytesIO(base64.b64decode(i.split(",", 1)[0])))
|
image = Image.open(io.BytesIO(base64.b64decode(i.split(",", 1)[0])))
|
||||||
try:
|
try:
|
||||||
|
|
@ -298,7 +427,16 @@ class StableDiffusion(commands.Cog):
|
||||||
await ctx.send(f'Generated by: {ctx.author.name}\nPrompt: {prompt}', file=f)
|
await ctx.send(f'Generated by: {ctx.author.name}\nPrompt: {prompt}', file=f)
|
||||||
|
|
||||||
|
|
||||||
def get_negative_prompt(self):
|
"""
|
||||||
|
Gets a negative prompt text from a file.
|
||||||
|
|
||||||
|
If the file does not exist, it will be created with
|
||||||
|
default negative prompt text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The negative prompt text loaded from the file.
|
||||||
|
"""
|
||||||
|
def get_negative_prompt(self) -> str:
|
||||||
try:
|
try:
|
||||||
neg_prompt_file = f"{self.data_dir}negative_prompt.txt"
|
neg_prompt_file = f"{self.data_dir}negative_prompt.txt"
|
||||||
with open(neg_prompt_file, 'r') as f:
|
with open(neg_prompt_file, 'r') as f:
|
||||||
|
|
@ -310,13 +448,12 @@ class StableDiffusion(commands.Cog):
|
||||||
negative_prompt = self.default_neg_prompt
|
negative_prompt = self.default_neg_prompt
|
||||||
return negative_prompt
|
return negative_prompt
|
||||||
|
|
||||||
|
|
||||||
@commands.command(
|
@commands.command(
|
||||||
description="Imagine",
|
description="Imagine",
|
||||||
help="Generate an image using stable diffusion. You can add keyword arguments to your prompt and they will be treated as stable diffusion options. Usage !imagine (topic)",
|
help="Generate an image using stable diffusion. You can add keyword arguments to your prompt and they will be treated as stable diffusion options. Usage !imagine (topic)",
|
||||||
brief="Generate an image"
|
brief="Generate an image"
|
||||||
)
|
)
|
||||||
async def imagine(self, ctx):
|
async def imagine(self, ctx: commands.Context) -> None:
|
||||||
url = self.stable_diffusion_url
|
url = self.stable_diffusion_url
|
||||||
if url == "disabled":
|
if url == "disabled":
|
||||||
await ctx.send("Command is currently disabled")
|
await ctx.send("Command is currently disabled")
|
||||||
|
|
@ -335,7 +472,7 @@ class StableDiffusion(commands.Cog):
|
||||||
help="Get better understanding of what the bot \"sees\" when you post an image! (Runs it through CLIP) Usage !describe (image link)",
|
help="Get better understanding of what the bot \"sees\" when you post an image! (Runs it through CLIP) Usage !describe (image link)",
|
||||||
brief="Describe image"
|
brief="Describe image"
|
||||||
)
|
)
|
||||||
async def describe(self, ctx):
|
async def describe(self, ctx: commands.Context) -> None:
|
||||||
url = self.stable_diffusion_url
|
url = self.stable_diffusion_url
|
||||||
if url == "disabled":
|
if url == "disabled":
|
||||||
await ctx.send("Command is currently disabled")
|
await ctx.send("Command is currently disabled")
|
||||||
|
|
@ -359,7 +496,7 @@ class StableDiffusion(commands.Cog):
|
||||||
help="Reimagine an image as something else. One example is reimagining a picture as anime. This command can be hard to use. \nUsage: !reimagine (image link) (topic)\nExample: !reimagine (image link) anime",
|
help="Reimagine an image as something else. One example is reimagining a picture as anime. This command can be hard to use. \nUsage: !reimagine (image link) (topic)\nExample: !reimagine (image link) anime",
|
||||||
brief="Reimagine an image"
|
brief="Reimagine an image"
|
||||||
)
|
)
|
||||||
async def reimagine(self, ctx):
|
async def reimagine(self, ctx: commands.Context) -> None:
|
||||||
url = self.stable_diffusion_url
|
url = self.stable_diffusion_url
|
||||||
if url == "disabled":
|
if url == "disabled":
|
||||||
await ctx.send("Command is currently disabled")
|
await ctx.send("Command is currently disabled")
|
||||||
|
|
@ -375,7 +512,7 @@ class StableDiffusion(commands.Cog):
|
||||||
help="Changes the negative prompt for imagine across all channels",
|
help="Changes the negative prompt for imagine across all channels",
|
||||||
brief="Change the negative prompt for imagine"
|
brief="Change the negative prompt for imagine"
|
||||||
)
|
)
|
||||||
async def negative_prompt(self, ctx, *args):
|
async def negative_prompt(self, ctx: commands.Context, *args: list) -> None:
|
||||||
message = ' '.join(args)
|
message = ' '.join(args)
|
||||||
if not message:
|
if not message:
|
||||||
message = self.default_neg_prompt
|
message = self.default_neg_prompt
|
||||||
|
|
@ -385,6 +522,6 @@ class StableDiffusion(commands.Cog):
|
||||||
await ctx.send("Changed negative prompt to " + message)
|
await ctx.send("Changed negative prompt to " + message)
|
||||||
|
|
||||||
|
|
||||||
async def setup(bot):
|
async def setup(bot: commands.Bot):
|
||||||
await bot.add_cog(StableDiffusion(bot))
|
await bot.add_cog(StableDiffusion(bot))
|
||||||
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue