diff --git a/plugins/stable_diffusion.py b/plugins/stable_diffusion.py index 652d732..81a7043 100644 --- a/plugins/stable_diffusion.py +++ b/plugins/stable_diffusion.py @@ -35,20 +35,16 @@ class StableDiffusion(commands.Cog): 'Content-Type': 'application/json', 'Authorization': f'Bearer {os.getenv("openai.api_key")}', } - data = { "model": model, "messages": [{"role": "user", "content": topic}] } - url = "https://api.openai.com/v1/chat/completions" - try: async with self.bot.http_session.post(url, headers=headers, json=data) as resp: response_data = await resp.json() response = response_data['choices'][0]['message']['content'] return response - except Exception as error: return await self.handle_error(error) @@ -61,29 +57,22 @@ class StableDiffusion(commands.Cog): f.write(log_line) return error - - async def extract_key_value_pairs(self, input_str): - output_str = input_str - key_value_pairs = {} - tokens = input_str.split(', ') - for token in tokens: - if '=' in token: - key, value = token.split('=') - key_value_pairs[key] = value - output_str = output_str.replace(token+', ', '') # Remove the key-value pair from the output string - - return key_value_pairs, output_str - def get_kv_from_ctx(self, ctx): - prompt = ctx.message.content.split(" ", maxsplit=1)[1] - kv_strings = list(filter(lambda x: '=' in x,prompt.split(' '))) - key_value_pairs = dict(map(lambda a: a.replace(',','').split('='),kv_strings)) - return key_value_pairs + try: + prompt = ctx.message.content.split(" ", maxsplit=1)[1] + kv_strings = list(filter(lambda x: '=' in x,prompt.split(' '))) + key_value_pairs = dict(map(lambda a: a.replace(',','').split('='),kv_strings)) + return key_value_pairs + except: + return None def get_prompt_from_ctx(self, ctx): - prompt = ctx.message.content.split(" ", maxsplit=1)[1] - prompt = ' '.join(list(filter(lambda x: '=' not in x,prompt.split(' ')))) - return prompt + try: + prompt = ctx.message.content.split(" ", maxsplit=1)[1] + prompt = ' '.join(list(filter(lambda x: '=' not in x,prompt.split(' ')))) + return prompt + except: + return None async def my_open_img_file(self, path): img = Image.open(path) @@ -128,72 +117,25 @@ class StableDiffusion(commands.Cog): await self.handle_error(error) return "ERROR: CLIP may not be running. Could not look at image." return metadata + + async def generate_prompt(self): + choice1 = "Give me 11 keywords I can use to generate art using AI. They should all be related to one piece of art. Please only respond with the keywords and no other text. Be sure to use keywords that really describe what the art portrays. Keywords should be comma separated with no other text!" + choice2 = "Describe a creative scene, use only one sentence" + choice3 = "Give me comma seperated keywords describing an imaginary piece of art. Only return the keywords and no other text." + choice4 = "Describe a unique character and an environment in one sentence" + choice5 = "Describe a nonhuman character and an environment in one sentence" + prompt = random.choice([choice1,choice2,choice3,choice4,choice5]) + prompt = await self.answer_question(prompt) + if random.randint(0,9): + prompt = prompt.replace("abstract, ", "") + prompt = prompt.replace("AI, ", "") + if "." in prompt: + prompt = prompt.replace(".",",") + prompt = prompt + " masterpiece, studio quality" + else: + prompt = prompt + ", masterpiece, studio quality" + return prompt - @commands.command( - description="Draw", - help="Generates a picture using stable diffusion and gpt 3.5. It generates a list of 10 random artistic words and feeds them into stable diffusion. Usage: !draw (amount of pictures)", - brief="Generate a random image" - ) - async def draw(self, ctx): - url = self.stable_diffusion_url - if url == "disabled": - return - try: - if " " in ctx.message.content: - amount = ctx.message.content.split(" ", maxsplit=1)[1] - if int(amount) > 4: - await ctx.send("No, that's too many.") - return - else: - amount = 1 - await ctx.send("Please be patient this may take some time!") - - choice1 = "Give me 11 keywords I can use to generate art using AI. They should all be related to one piece of art. Please only respond with the keywords and no other text. Be sure to use keywords that really describe what the art portrays. Keywords should be comma separated with no other text!" - choice2 = "Describe a creative scene, use only one sentence" - choice3 = "Give me comma seperated keywords describing an imaginary piece of art. Only return the keywords and no other text." - choice4 = "Describe a unique character and an environment in one sentence" - choice5 = "Describe a nonhuman character and an environment in one sentence" - prompt = random.choice([choice1,choice2,choice3,choice4,choice5]) - prompt = await self.answer_question(prompt) - if random.randint(0,9): - prompt = prompt.replace("abstract, ", "") - prompt = prompt.replace("AI, ", "") - if "." in prompt: - prompt = prompt.replace(".",",") - prompt = prompt + " masterpiece, studio quality" - else: - prompt = prompt + ", masterpiece, studio quality" - negative_prompt = "easynegative verybadimagenegative_v1.3" - payload = {"prompt": prompt,"steps": 25, "negative_prompt": negative_prompt,"batch_size": amount} - try: - async with self.bot.http_session.post(url=f'{url}/sdapi/v1/txt2img', json=payload) as resp: - r = await resp.json() - for i in r['images']: - image = Image.open(io.BytesIO(base64.b64decode(i.split(",",1)[0]))) - png_payload = {"image": "data:image/png;base64," + i} - async with self.bot.http_session.post(url=f'{url}/sdapi/v1/png-info', json=png_payload) as resp2: - response2 = await resp2.json() - pnginfo = PngImagePlugin.PngInfo() - pnginfo.add_text("parameters", response2.get("info")) - try: - if ctx.channel.is_nsfw(): - folder = self.working_dir + "nsfw/" - else: - folder = self.working_dir + "sfw/" - except: - folder = self.working_dir - my_filename = folder + str(time.time_ns()) + ".png" - image.save(my_filename, pnginfo=pnginfo) - with open(my_filename, "rb") as fh: - f = discord.File(fh, filename=my_filename) - await ctx.send(prompt, file=f) - except Exception as error: - await self.handle_error(error) - await ctx.send("My image generation service may not be running.") - except Exception as error: - await self.handle_error(error) - await ctx.send('Did you mean to use !imagine?. Usage: !draw (number)') - await self.bot.http_session.close() @commands.command( description="Change Model", @@ -264,8 +206,10 @@ class StableDiffusion(commands.Cog): return else: url=f"{url}/sdapi/v1/txt2img" - prompt = ctx.message.content.split(" ", maxsplit=1)[1] - key_value_pairs, prompt = await self.extract_key_value_pairs(prompt) + prompt = self.get_prompt_from_ctx(ctx) + key_value_pairs = self.get_kv_from_ctx(ctx) + if prompt == None: + prompt = await self.generate_prompt() try: neg_prompt_file = f"{self.db_dir}negative_prompt.txt" with open(neg_prompt_file, 'r') as f: @@ -284,7 +228,8 @@ class StableDiffusion(commands.Cog): headers = { 'Content-Type': 'application/json' } - payload.update(key_value_pairs) + if key_value_pairs: + payload.update(key_value_pairs) try: async with self.bot.http_session.post(url, headers=headers, json=payload) as resp: r = await resp.json() @@ -309,9 +254,9 @@ class StableDiffusion(commands.Cog): if ctx.channel.is_nsfw(): folder = self.working_dir + "nsfw/" else: - folder = self.working_dir + "sfw/" + folder = self.working_dir + "sfw" except: - folder = self.working_dir + str(ctx.author.id) + '/' + folder = self.working_dir my_filename = folder + str(time.time_ns()) + ".png" image.save(my_filename, pnginfo=pnginfo) @@ -394,7 +339,8 @@ class StableDiffusion(commands.Cog): await self.handle_error(error) print("Couldn't find image.") return - key_value_pairs, prompt = await self.extract_key_value_pairs(prompt) + prompt = self.get_prompt_from_ctx(ctx) + key_value_pairs = self.get_kv_from_ctx(ctx) try: async with self.bot.http_session.get(file_url) as response: imageName = self.working_dir + str(time.time_ns()) + ".png"