moving all plugins to extensions folder
This commit is contained in:
parent
37c77680b4
commit
65117c42e8
1 changed files with 41 additions and 95 deletions
|
|
@ -35,20 +35,16 @@ class StableDiffusion(commands.Cog):
|
|||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {os.getenv("openai.api_key")}',
|
||||
}
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": topic}]
|
||||
}
|
||||
|
||||
url = "https://api.openai.com/v1/chat/completions"
|
||||
|
||||
try:
|
||||
async with self.bot.http_session.post(url, headers=headers, json=data) as resp:
|
||||
response_data = await resp.json()
|
||||
response = response_data['choices'][0]['message']['content']
|
||||
return response
|
||||
|
||||
except Exception as error:
|
||||
return await self.handle_error(error)
|
||||
|
||||
|
|
@ -61,29 +57,22 @@ class StableDiffusion(commands.Cog):
|
|||
f.write(log_line)
|
||||
return error
|
||||
|
||||
|
||||
async def extract_key_value_pairs(self, input_str):
|
||||
output_str = input_str
|
||||
key_value_pairs = {}
|
||||
tokens = input_str.split(', ')
|
||||
for token in tokens:
|
||||
if '=' in token:
|
||||
key, value = token.split('=')
|
||||
key_value_pairs[key] = value
|
||||
output_str = output_str.replace(token+', ', '') # Remove the key-value pair from the output string
|
||||
|
||||
return key_value_pairs, output_str
|
||||
|
||||
def get_kv_from_ctx(self, ctx):
|
||||
try:
|
||||
prompt = ctx.message.content.split(" ", maxsplit=1)[1]
|
||||
kv_strings = list(filter(lambda x: '=' in x,prompt.split(' ')))
|
||||
key_value_pairs = dict(map(lambda a: a.replace(',','').split('='),kv_strings))
|
||||
return key_value_pairs
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_prompt_from_ctx(self, ctx):
|
||||
try:
|
||||
prompt = ctx.message.content.split(" ", maxsplit=1)[1]
|
||||
prompt = ' '.join(list(filter(lambda x: '=' not in x,prompt.split(' '))))
|
||||
return prompt
|
||||
except:
|
||||
return None
|
||||
|
||||
async def my_open_img_file(self, path):
|
||||
img = Image.open(path)
|
||||
|
|
@ -129,25 +118,7 @@ class StableDiffusion(commands.Cog):
|
|||
return "ERROR: CLIP may not be running. Could not look at image."
|
||||
return metadata
|
||||
|
||||
@commands.command(
|
||||
description="Draw",
|
||||
help="Generates a picture using stable diffusion and gpt 3.5. It generates a list of 10 random artistic words and feeds them into stable diffusion. Usage: !draw (amount of pictures)",
|
||||
brief="Generate a random image"
|
||||
)
|
||||
async def draw(self, ctx):
|
||||
url = self.stable_diffusion_url
|
||||
if url == "disabled":
|
||||
return
|
||||
try:
|
||||
if " " in ctx.message.content:
|
||||
amount = ctx.message.content.split(" ", maxsplit=1)[1]
|
||||
if int(amount) > 4:
|
||||
await ctx.send("No, that's too many.")
|
||||
return
|
||||
else:
|
||||
amount = 1
|
||||
await ctx.send("Please be patient this may take some time!")
|
||||
|
||||
async def generate_prompt(self):
|
||||
choice1 = "Give me 11 keywords I can use to generate art using AI. They should all be related to one piece of art. Please only respond with the keywords and no other text. Be sure to use keywords that really describe what the art portrays. Keywords should be comma separated with no other text!"
|
||||
choice2 = "Describe a creative scene, use only one sentence"
|
||||
choice3 = "Give me comma seperated keywords describing an imaginary piece of art. Only return the keywords and no other text."
|
||||
|
|
@ -163,37 +134,8 @@ class StableDiffusion(commands.Cog):
|
|||
prompt = prompt + " masterpiece, studio quality"
|
||||
else:
|
||||
prompt = prompt + ", masterpiece, studio quality"
|
||||
negative_prompt = "easynegative verybadimagenegative_v1.3"
|
||||
payload = {"prompt": prompt,"steps": 25, "negative_prompt": negative_prompt,"batch_size": amount}
|
||||
try:
|
||||
async with self.bot.http_session.post(url=f'{url}/sdapi/v1/txt2img', json=payload) as resp:
|
||||
r = await resp.json()
|
||||
for i in r['images']:
|
||||
image = Image.open(io.BytesIO(base64.b64decode(i.split(",",1)[0])))
|
||||
png_payload = {"image": "data:image/png;base64," + i}
|
||||
async with self.bot.http_session.post(url=f'{url}/sdapi/v1/png-info', json=png_payload) as resp2:
|
||||
response2 = await resp2.json()
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
pnginfo.add_text("parameters", response2.get("info"))
|
||||
try:
|
||||
if ctx.channel.is_nsfw():
|
||||
folder = self.working_dir + "nsfw/"
|
||||
else:
|
||||
folder = self.working_dir + "sfw/"
|
||||
except:
|
||||
folder = self.working_dir
|
||||
my_filename = folder + str(time.time_ns()) + ".png"
|
||||
image.save(my_filename, pnginfo=pnginfo)
|
||||
with open(my_filename, "rb") as fh:
|
||||
f = discord.File(fh, filename=my_filename)
|
||||
await ctx.send(prompt, file=f)
|
||||
except Exception as error:
|
||||
await self.handle_error(error)
|
||||
await ctx.send("My image generation service may not be running.")
|
||||
except Exception as error:
|
||||
await self.handle_error(error)
|
||||
await ctx.send('Did you mean to use !imagine?. Usage: !draw (number)')
|
||||
await self.bot.http_session.close()
|
||||
return prompt
|
||||
|
||||
|
||||
@commands.command(
|
||||
description="Change Model",
|
||||
|
|
@ -264,8 +206,10 @@ class StableDiffusion(commands.Cog):
|
|||
return
|
||||
else:
|
||||
url=f"{url}/sdapi/v1/txt2img"
|
||||
prompt = ctx.message.content.split(" ", maxsplit=1)[1]
|
||||
key_value_pairs, prompt = await self.extract_key_value_pairs(prompt)
|
||||
prompt = self.get_prompt_from_ctx(ctx)
|
||||
key_value_pairs = self.get_kv_from_ctx(ctx)
|
||||
if prompt == None:
|
||||
prompt = await self.generate_prompt()
|
||||
try:
|
||||
neg_prompt_file = f"{self.db_dir}negative_prompt.txt"
|
||||
with open(neg_prompt_file, 'r') as f:
|
||||
|
|
@ -284,6 +228,7 @@ class StableDiffusion(commands.Cog):
|
|||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
if key_value_pairs:
|
||||
payload.update(key_value_pairs)
|
||||
try:
|
||||
async with self.bot.http_session.post(url, headers=headers, json=payload) as resp:
|
||||
|
|
@ -309,9 +254,9 @@ class StableDiffusion(commands.Cog):
|
|||
if ctx.channel.is_nsfw():
|
||||
folder = self.working_dir + "nsfw/"
|
||||
else:
|
||||
folder = self.working_dir + "sfw/"
|
||||
folder = self.working_dir + "sfw"
|
||||
except:
|
||||
folder = self.working_dir + str(ctx.author.id) + '/'
|
||||
folder = self.working_dir
|
||||
my_filename = folder + str(time.time_ns()) + ".png"
|
||||
image.save(my_filename, pnginfo=pnginfo)
|
||||
|
||||
|
|
@ -394,7 +339,8 @@ class StableDiffusion(commands.Cog):
|
|||
await self.handle_error(error)
|
||||
print("Couldn't find image.")
|
||||
return
|
||||
key_value_pairs, prompt = await self.extract_key_value_pairs(prompt)
|
||||
prompt = self.get_prompt_from_ctx(ctx)
|
||||
key_value_pairs = self.get_kv_from_ctx(ctx)
|
||||
try:
|
||||
async with self.bot.http_session.get(file_url) as response:
|
||||
imageName = self.working_dir + str(time.time_ns()) + ".png"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue