# matrix-nsfwbot/nsfwbot.py
# (Repository viewer metadata at time of capture: 250 lines, 10 KiB, Python)

import os
import time
from bs4 import BeautifulSoup
from collections import defaultdict
from maubot import Plugin
from nudenet import NudeDetector
from uuid import uuid4
from typing import List, Type, Tuple
from asyncio import Semaphore
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
from mautrix.types import (
MessageEvent, MessageType, RoomAlias, RoomID, EventID, TextMessageEventContent, MediaMessageEventContent
)
from mautrix.errors import MBadJSON, MForbidden
from maubot.handlers import command, event
# Initialize NudeDetector
# A single module-level detector instance is created at import time; it is the
# object process_images() calls detect() on.
detector = NudeDetector()
# Score threshold: send_responses() treats a detection whose class is in
# block_labels but whose score is <= min_score as SFW.
min_score = 0.35
# Detection class names that count as NSFW. Detections whose 'class' is not in
# this list are ignored by format_response() and send_responses().
block_labels = [
    "FEMALE_GENITALIA_COVERED",
    "BUTTOCKS_EXPOSED",
    "FEMALE_BREAST_EXPOSED",
    "FEMALE_GENITALIA_EXPOSED",
    "MALE_GENITALIA_EXPOSED",
    "ANUS_EXPOSED",
    "ANUS_COVERED"
]
class Config(BaseProxyConfig):
    """
    Plugin configuration for NSFWModelPlugin.

    Only the keys listed in do_update are carried over when the config is updated.
    """

    def do_update(self, helper: ConfigUpdateHelper) -> None:
        """
        Copy every supported top-level config key through the update helper.
        :param helper: The update helper supplied by the config framework.
        """
        for config_key in ("max_concurrent_jobs", "via_servers", "actions"):
            helper.copy(config_key)
class NSFWModelPlugin(Plugin):
    """
    Maubot plugin that runs the NudeNet detector on image messages and acts on
    NSFW detections (reply, report to a room, redact) according to the config.

    It also rate-limits images per user: once a user has ``max_images`` tracked
    images inside ``time_window`` seconds, those images and every further image
    sent inside the window are redacted.
    """

    # Class-level defaults; start() replaces them with per-instance values.
    semaphore = Semaphore(1)
    via_servers = []
    actions = {}
    report_to_room = ""
    # Track images sent by each user
    user_image_data = defaultdict(list)  # {user_id: [(timestamp, room_id, event_id)]}
    max_images = 3  # Max number of images per user inside time_window
    time_window = 60 * 1  # Time window in seconds (60 s = 1 minute)

    @classmethod
    def get_config_class(cls) -> Type[BaseProxyConfig]:
        """
        :return: The config class maubot should instantiate for this plugin.
        """
        return Config

    async def start(self) -> None:
        """
        Load the configuration and prepare per-instance state.

        Configuration errors are logged instead of raised. When the report room
        is missing, invalid, or its alias cannot be resolved, reporting is left
        disabled (``report_to_room`` stays empty).
        """
        await super().start()
        # Rebind mutable state on the instance so that several plugin instances
        # do not share (and mutate) the class-level containers.
        self.user_image_data = defaultdict(list)
        self.report_to_room = ""
        try:
            if not isinstance(self.config, Config):
                self.log.error("Plugin not yet configured.")
                return
            self.config.load_and_update()
            self.via_servers = self.config["via_servers"] or []
            self.actions = self.config["actions"] or {}

            # A missing or non-positive value would make Semaphore() raise or
            # block every job forever, so fall back to a single job.
            try:
                max_concurrent_jobs = int(self.config["max_concurrent_jobs"])
            except (TypeError, ValueError):
                max_concurrent_jobs = 0
            if max_concurrent_jobs < 1:
                self.log.warning("Invalid max_concurrent_jobs in config; falling back to 1")
                max_concurrent_jobs = 1
            self.semaphore = Semaphore(max_concurrent_jobs)

            # `or ""` keeps a null config value from becoming the string "None".
            report_to = str(self.actions.get("report_to_room") or "")
            if report_to.startswith("#"):
                report_to_info = await self.client.resolve_room_alias(RoomAlias(report_to))
                report_to = report_to_info.room_id
            elif report_to and not report_to.startswith("!"):
                self.log.warning("Invalid room ID or alias provided for report_to_room")
                report_to = ""
            self.report_to_room = report_to
            self.log.info("Loaded nsfwbot successfully")
        except Exception as e:
            self.log.error(f"Error during start: {e}")

    @command.passive(
        "^mxc://.+/.+$",
        field=lambda evt: evt.content.url or "",  # type:ignore
        msgtypes=(MessageType.IMAGE,),
    )
    async def handle_image_message(self, evt: MessageEvent, url: Tuple[str]) -> None:
        """
        Handle direct image messages with rate limiting and redact previous images.
        :param evt: The image message event.
        :param url: Match groups from the passive command regex (unused).
        """
        user_id = evt.sender  # The user who sent the image
        current_time = time.time()

        # Clean up old entries that are outside of the time window
        recent_images = [
            entry for entry in self.user_image_data[user_id]
            if current_time - entry[0] <= self.time_window
        ]
        self.user_image_data[user_id] = recent_images

        # Check if user exceeded the image limit
        if len(recent_images) >= self.max_images:
            reason = "Too many images sent in a short period"
            # Redact all images sent within the time window, each in the room
            # it was actually posted to (tracking is per user, not per room).
            for _, room_id, event_id in recent_images:
                try:
                    await self.client.redact(room_id, event_id, reason=reason)
                    self.log.info(f"Redacted image sent by {user_id} (event ID: {event_id}) due to rate limit.")
                except Exception as e:
                    self.log.error(f"Failed to redact image (event ID: {event_id}): {e}")
            # Also redact the current image
            try:
                await self.client.redact(evt.room_id, evt.event_id, reason=reason)
                self.log.warning(f"User {user_id} exceeded the image limit. Current image redacted.")
            except Exception as e:
                self.log.error(f"Failed to redact image (event ID: {evt.event_id}): {e}")
            return

        # Add current timestamp, room ID and event ID
        self.user_image_data[user_id].append((current_time, evt.room_id, evt.event_id))

        try:
            if not isinstance(evt.content, MediaMessageEventContent) or not evt.content.url:
                return
            results = await self.process_images([evt.content.url])
            # process_images() returns {} when downloading or detection failed.
            if not results:
                return
            matrix_to_url = self.create_matrix_to_url(evt.room_id, evt.event_id)
            response = self.format_response(results, matrix_to_url, evt.sender)
            await self.send_responses(evt, response, results)
        except Exception as e:
            self.log.error(f"Error handling image message: {e}")

    async def process_images(self, mxc_urls: List[str]) -> dict:
        """
        Download and process the images using the NudeNet detector.
        :param mxc_urls: List of MXC URLs of the images.
        :return: Dictionary of results with MXC URLs as keys and detection results as values.
                 An empty dictionary is returned when any download or detection fails.
        """
        import tempfile  # function-scope import: only needed here

        async with self.semaphore:
            try:
                # The directory and everything in it is removed when the block
                # exits, including on errors.
                with tempfile.TemporaryDirectory(prefix="nsfwbot-") as temp_dir:
                    results = {}
                    for mxc_url in mxc_urls:
                        img_bytes = await self.client.download_media(mxc_url)  # Download image from MXC URL
                        temp_filename = os.path.join(temp_dir, f"{uuid4()}.jpg")
                        with open(temp_filename, "wb") as img_file:
                            img_file.write(img_bytes)
                        # NOTE(review): detect() is called synchronously, so the
                        # event loop waits for it; consider an executor if
                        # inference turns out to be slow.
                        results[mxc_url] = detector.detect(temp_filename)
                    return results
            except Exception as e:
                self.log.error(f"Error processing images: {e}")
                return {}

    @staticmethod
    def _is_blocked(detection: dict) -> bool:
        """Return True when a detection has a blocked label and a score above min_score."""
        return detection['class'] in block_labels and detection['score'] > min_score

    def format_response(self, results: dict, matrix_to_url: str, sender: str) -> str:
        """
        Format the response message based on the results.

        Only the first detection per image that has a blocked label and a score
        above ``min_score`` is reported, matching the NSFW decision made in
        send_responses().
        :param results: Dictionary of results with MXC URLs as keys and detection results as values.
        :param matrix_to_url: The matrix.to URL for the original message.
        :param sender: The user who sent the message.
        :return: The formatted response message (empty when nothing is blocked).
        """
        response_parts = []
        for mxc_url, detections in results.items():
            blocked = next((d for d in detections if self._is_blocked(d)), None)
            if blocked is None:
                continue
            response_parts.append(
                f"{mxc_url} contains {blocked['class']} "
                f"with a score of {blocked['score']:.2f} "
                f"in {matrix_to_url} "
                f"by {sender}"
            )
        return "\n".join(response_parts)

    async def send_responses(self, evt: MessageEvent, response: str, results: dict) -> None:
        """
        Send responses or take actions based on config.

        Each action (reply, report, redact) is attempted independently so that a
        failure in one does not prevent the others.
        :param evt: The message event.
        :param response: The formatted response message.
        :param results: Dictionary of results with MXC URLs as keys and detection results as values.
        """
        try:
            detected = False
            for mxc_url, detections in results.items():
                for detection in detections:
                    if detection['class'] not in block_labels:
                        continue
                    if detection['score'] <= min_score:
                        self.log.info(f"{mxc_url} in class {detection['class']} is SFW because score {detection['score']} is below minimum {min_score} score.")
                    else:
                        detected = True
                        self.log.info(f"{mxc_url} is NSFW because {detection['class']} is blocked (score: {detection['score']})")
            if not detected:
                self.log.info(f"{evt.room_id} is SFW")

            if self.actions.get("ignore_sfw", False) and not detected:
                self.log.info(f"Ignored SFW images in {evt.room_id}")
                return

            # Never send an empty message: the response is empty when nothing
            # was blocked.
            if response and self.actions.get("direct_reply", False):
                try:
                    await evt.reply(response)
                except Exception as e:
                    self.log.warning(f"Failed to reply in {evt.room_id}: {e}")

            if response and self.report_to_room:
                try:
                    await self.client.send_text(room_id=RoomID(self.report_to_room), text=response)
                except Exception as e:
                    self.log.warning(f"Failed to send message to {RoomID(self.report_to_room)}: {e}")

            if detected and self.actions.get("redact_nsfw", False):
                try:
                    await self.client.redact(room_id=evt.room_id, event_id=evt.event_id, reason="NSFW")
                    self.log.info(f"Redacted NSFW message in {evt.room_id}")
                except MForbidden:
                    self.log.warning(f"Failed to redact NSFW message in {evt.room_id}")
        except Exception as e:
            self.log.error(f"Error sending responses: {e}")

    def create_matrix_to_url(self, room_id: RoomID, event_id: EventID) -> str:
        """
        Create a matrix.to URL for a given room ID and event ID.
        :param room_id: The room ID.
        :param event_id: The event ID.
        :return: The matrix.to URL, with one ``via=`` query parameter per configured server.
        """
        via_params = ""
        if self.via_servers:
            via_params = "?" + "&".join(f"via={server}" for server in self.via_servers)
        return f"https://matrix.to/#/{room_id}/{event_id}{via_params}"

    def extract_img_tags(self, html: str) -> List[str]:
        """
        Extract image URLs from <img> tags in the HTML content.
        :param html: The HTML content.
        :return: List of image URLs (the ``src`` of every <img> tag that has one).
        """
        soup = BeautifulSoup(html, "html.parser")
        return [img["src"] for img in soup.find_all("img") if "src" in img.attrs]