Now save generated stable diffusion images in tmp/sfw or nsfw

Now save generated memes in tmp/meme/
This commit is contained in:
phixxy 2024-01-17 19:25:00 -08:00
parent 8aeb0142a6
commit b2fb09e681
3 changed files with 22 additions and 9 deletions

View file

@ -107,7 +107,7 @@ async def meme(ctx):
#------------------------------------Saving Image Using Aiohttp---------------------------------# #------------------------------------Saving Image Using Aiohttp---------------------------------#
filename = memepics[id-1]['name'] filename = memepics[id-1]['name']
async with http_session.get(image_link) as response: async with http_session.get(image_link) as response:
folder = "tmp/" folder = "tmp/meme/"
filename = folder + topic + str(len(os.listdir(folder))) + ".jpg" filename = folder + topic + str(len(os.listdir(folder))) + ".jpg"
with open(filename, "wb") as file: with open(filename, "wb") as file:

View file

@ -150,11 +150,15 @@ async def draw(ctx):
response2 = await resp2.json() response2 = await resp2.json()
pnginfo = PngImagePlugin.PngInfo() pnginfo = PngImagePlugin.PngInfo()
pnginfo.add_text("parameters", response2.get("info")) pnginfo.add_text("parameters", response2.get("info"))
my_filename = "tmp/" + str(len(os.listdir("tmp/"))) + ".png" try:
if ctx.channel.is_nsfw():
folder = "tmp/nsfw/"
else:
folder = "tmp/sfw/"
except:
folder = "tmp/"
my_filename = folder + str(time.time_ns()) + ".png"
image.save(my_filename, pnginfo=pnginfo) image.save(my_filename, pnginfo=pnginfo)
'''channel_vars = await get_channel_config(ctx.channel.id)
if channel_vars["ftp_enabled"]:
await upload_ftp_ai_images(my_filename, prompt)'''
with open(my_filename, "rb") as fh: with open(my_filename, "rb") as fh:
f = discord.File(fh, filename=my_filename) f = discord.File(fh, filename=my_filename)
await ctx.send(file=f) await ctx.send(file=f)
@ -266,6 +270,10 @@ async def imagine(ctx):
for i in r['images']: for i in r['images']:
if not os.path.isdir("users/" + str(ctx.author.id)): if not os.path.isdir("users/" + str(ctx.author.id)):
os.makedirs("users/" + str(ctx.author.id)) os.makedirs("users/" + str(ctx.author.id))
if not os.path.isdir("users/" + str(ctx.author.id) + '/nsfw/'):
os.makedirs("users/" + str(ctx.author.id) + '/nsfw/')
if not os.path.isdir("users/" + str(ctx.author.id) + '/sfw/'):
os.makedirs("users/" + str(ctx.author.id) + '/sfw/')
image = Image.open(io.BytesIO(base64.b64decode(i.split(",", 1)[0]))) image = Image.open(io.BytesIO(base64.b64decode(i.split(",", 1)[0])))
png_payload = {"image": "data:image/png;base64," + i} png_payload = {"image": "data:image/png;base64," + i}
@ -280,8 +288,14 @@ async def imagine(ctx):
pnginfo = PngImagePlugin.PngInfo() pnginfo = PngImagePlugin.PngInfo()
pnginfo.add_text("parameters", response2.get("info")) pnginfo.add_text("parameters", response2.get("info"))
try:
my_filename = "users/" + str(ctx.author.id) + '/' + str(len(os.listdir("users/" + str(ctx.author.id) + '/'))) + ".png" if ctx.channel.is_nsfw():
folder = "users/" + str(ctx.author.id) + '/nsfw/'
else:
folder = "users/" + str(ctx.author.id) + '/sfw/'
except:
folder = "users/" + str(ctx.author.id) + '/'
my_filename = folder + str(time.time_ns()) + ".png"
image.save(my_filename, pnginfo=pnginfo) image.save(my_filename, pnginfo=pnginfo)
#channel_vars = await get_channel_config(ctx.channel.id) #channel_vars = await get_channel_config(ctx.channel.id)

View file

@ -237,7 +237,7 @@ async def chat_response(ctx, channel_vars, chat_history_string):
await handle_error(error) await handle_error(error)
async def folder_setup(): async def folder_setup():
folder_names = ["tmp", "channels", "users", "channels/config", "channels/logs", "databases", "databases/currency", "databases/currency/players"] folder_names = ["tmp", "tmp/sfw", "tmp/nsfw", "tmp/meme/", "channels", "users", "channels/config", "channels/logs", "databases", "databases/currency", "databases/currency/players"]
for folder_name in folder_names: for folder_name in folder_names:
if not os.path.exists(folder_name): if not os.path.exists(folder_name):
os.mkdir(folder_name) os.mkdir(folder_name)
@ -1038,7 +1038,6 @@ async def pkmn_msg(discord_id):
@bot.event @bot.event
async def on_message(ctx): async def on_message(ctx):
#log stuff #log stuff
logfile = "channels/logs/{0}.log".format(str(ctx.channel.id)) logfile = "channels/logs/{0}.log".format(str(ctx.channel.id))
channel_vars = await get_channel_config(ctx.channel.id) channel_vars = await get_channel_config(ctx.channel.id)