250 lines
10 KiB
Python
250 lines
10 KiB
Python
import asyncio
import os
import tempfile
import time
from asyncio import Semaphore
from collections import defaultdict
from typing import List, Type, Tuple
from uuid import uuid4

from bs4 import BeautifulSoup
from maubot import Plugin
from maubot.handlers import command, event
from mautrix.errors import MBadJSON, MForbidden
from mautrix.types import (
    MessageEvent, MessageType, RoomAlias, RoomID, EventID, TextMessageEventContent, MediaMessageEventContent
)
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
from nudenet import NudeDetector
|
|
|
|
|
|
# Module-level NudeNet detector, constructed once at import time and used for
# every image this plugin scans.
detector = NudeDetector()

# Score threshold for blocked classes: a detection scoring at or below this
# value is treated as SFW; only scores strictly above it mark an image NSFW.
min_score: float = 0.35

# NudeNet detection classes that are considered for flagging an image.
block_labels = [
    "FEMALE_GENITALIA_COVERED",
    "BUTTOCKS_EXPOSED",
    "FEMALE_BREAST_EXPOSED",
    "FEMALE_GENITALIA_EXPOSED",
    "MALE_GENITALIA_EXPOSED",
    "ANUS_EXPOSED",
    "ANUS_COVERED",
]
|
|
|
|
class Config(BaseProxyConfig):
    """
    Configuration manager for the NSFWModelPlugin.

    When the config is updated, the plugin's keys (``max_concurrent_jobs``,
    ``via_servers`` and ``actions``) are carried over from the existing config
    via ``ConfigUpdateHelper.copy``.
    """

    def do_update(self, helper: ConfigUpdateHelper) -> None:
        for key in ("max_concurrent_jobs", "via_servers", "actions"):
            helper.copy(key)
|
|
|
|
|
|
class NSFWModelPlugin(Plugin):
    """
    Maubot plugin that rate-limits image posts per user and scans images with
    NudeNet, replying/reporting/redacting according to the ``actions`` config.
    """

    # Defaults used until start() installs per-instance values from the config.
    semaphore = Semaphore(1)
    via_servers = []
    actions = {}
    report_to_room = ""

    # Track images sent by each user:
    #   {user_id: [(monotonic_timestamp, room_id, event_id_or_None)]}
    # The room ID is stored so each event is redacted in the room it was sent
    # in. event_id is replaced by None once the event has been redacted for
    # rate limiting: the entry still counts toward the window but is never
    # redacted twice. start() replaces this with a per-instance dict.
    user_image_data = defaultdict(list)

    max_images = 3  # Max number of images per user within time_window
    time_window = 60 * 1  # Time window in seconds (1 minute)

    @classmethod
    def get_config_class(cls) -> Type[BaseProxyConfig]:
        return Config

    async def start(self) -> None:
        """
        Initialise per-instance state and load the plugin config.

        Errors are logged, not raised, so a bad config leaves the plugin
        running with its defaults.
        """
        await super().start()
        # Per-instance state: the class-level containers above would otherwise
        # be shared by every instance of this plugin class.
        self.user_image_data = defaultdict(list)
        self.via_servers = []
        self.actions = {}
        self.report_to_room = ""
        self.semaphore = Semaphore(1)
        try:
            if not isinstance(self.config, Config):
                self.log.error("Plugin not yet configured.")
                return

            self.config.load_and_update()
            self.via_servers = self.config["via_servers"] or []
            self.actions = self.config["actions"] or {}

            max_concurrent_jobs = int(self.config["max_concurrent_jobs"])
            if max_concurrent_jobs < 1:
                # Semaphore(0) would block every image handler forever.
                self.log.warning(f"max_concurrent_jobs must be >= 1 (got {max_concurrent_jobs}); using 1")
                max_concurrent_jobs = 1
            self.semaphore = Semaphore(max_concurrent_jobs)

            self.report_to_room = await self._resolve_report_room(str(self.actions.get("report_to_room", "")))
            self.log.info("Loaded nsfwbot successfully")
        except Exception as e:
            self.log.exception(f"Error during start: {e}")

    async def _resolve_report_room(self, room: str) -> str:
        """
        Turn the configured report room into a room ID.

        :param room: A room ID (``!...``), a room alias (``#...``) or "".
        :return: The room ID, or "" when the value is empty, invalid or the
                 alias cannot be resolved (so no sends to it are attempted).
        """
        if not room or room.startswith("!"):
            return room
        if room.startswith("#"):
            try:
                report_to_info = await self.client.resolve_room_alias(RoomAlias(room))
            except Exception as e:
                self.log.warning(f"Could not resolve report_to_room alias {room}: {e}")
                return ""
            return report_to_info.room_id
        self.log.warning("Invalid room ID or alias provided for report_to_room")
        return ""

    @command.passive(
        "^mxc://.+/.+$",
        field=lambda evt: evt.content.url or "",  # type:ignore
        msgtypes=(MessageType.IMAGE,),
    )
    async def handle_image_message(self, evt: MessageEvent, url: Tuple[str]) -> None:
        """
        Handle direct image messages with rate limiting and redact previous images.

        If the sender is over the rate limit the image is redacted and not
        scanned; otherwise it is scanned and send_responses() decides what to do.
        """
        if await self._enforce_rate_limit(evt):
            return

        try:
            if not isinstance(evt.content, MediaMessageEventContent) or not evt.content.url:
                return
            results = await self.process_images([evt.content.url])
            if not results:
                # process_images returns {} on failure: nothing to report.
                return
            matrix_to_url = self.create_matrix_to_url(evt.room_id, evt.event_id)
            response = self.format_response(results, matrix_to_url, evt.sender)
            await self.send_responses(evt, response, results)
        except Exception as e:
            self.log.error(f"Error handling image message: {e}")

    async def _enforce_rate_limit(self, evt: MessageEvent) -> bool:
        """
        Record *evt* in the sender's window, or redact if the limit is exceeded.

        When the sender already has ``max_images`` images inside
        ``time_window`` seconds, every not-yet-redacted image in the window is
        redacted (each in its own room) along with the current one. The
        current image is not added to the window in that case.

        :return: True if the current image was over the limit (and redaction
                 was attempted), False if it was recorded normally.
        """
        user_id = evt.sender
        now = time.monotonic()  # interval measurement: immune to wall-clock jumps

        recent = [entry for entry in self.user_image_data[user_id] if now - entry[0] <= self.time_window]
        self.user_image_data[user_id] = recent

        if len(recent) < self.max_images:
            recent.append((now, evt.room_id, evt.event_id))
            return False

        # Mark entries as handled *before* awaiting, so a concurrent handler
        # call for the same user does not redact them a second time.
        pending = [(room_id, event_id) for _, room_id, event_id in recent if event_id is not None]
        recent[:] = [(timestamp, room_id, None) for timestamp, room_id, _ in recent]

        reason = "Too many images sent in a short period"
        for room_id, event_id in pending:
            try:
                await self.client.redact(room_id, event_id, reason=reason)
                self.log.info(f"Redacted image sent by {user_id} (event ID: {event_id}) due to rate limit.")
            except Exception as e:
                self.log.error(f"Failed to redact image (event ID: {event_id}): {e}")

        try:
            await self.client.redact(evt.room_id, evt.event_id, reason=reason)
            self.log.warning(f"User {user_id} exceeded the image limit. Current image redacted.")
        except Exception as e:
            self.log.error(f"Failed to redact image (event ID: {evt.event_id}): {e}")
        return True

    async def process_images(self, mxc_urls: List[str]) -> dict:
        """
        Download and process the images using the NudeNet detector.

        Downloads are awaited on the event loop; the blocking detector call
        runs in the loop's default executor so other handlers are not stalled.
        At most as many calls as ``self.semaphore`` allows run concurrently.

        :param mxc_urls: List of MXC URLs of the images.
        :return: Dictionary of results with MXC URLs as keys and detection
                 results as values, or an empty dict if any image fails.
        """
        loop = asyncio.get_running_loop()
        results = {}
        async with self.semaphore:
            try:
                for mxc_url in mxc_urls:
                    img_bytes = await self.client.download_media(mxc_url)  # Download image from MXC URL
                    results[mxc_url] = await loop.run_in_executor(None, self._detect_bytes, img_bytes)
            except Exception as e:
                self.log.error(f"Error processing images: {e}")
                return {}
        return results

    def _detect_bytes(self, img_bytes: bytes) -> list:
        """
        Run the detector on raw image bytes (blocking; runs in a worker thread).

        The bytes are written to a private temporary file (``tempfile.mkstemp``)
        whose path is passed to ``detector.detect``; the file is always removed.
        """
        fd, temp_filename = tempfile.mkstemp(suffix=".jpg")
        try:
            with os.fdopen(fd, "wb") as img_file:
                img_file.write(img_bytes)
            return detector.detect(temp_filename)  # Detect NSFW content with bounding boxes
        finally:
            try:
                os.remove(temp_filename)
            except OSError as e:
                # Never let cleanup failure mask the detection result/error.
                self.log.warning(f"Could not remove temporary file {temp_filename}: {e}")

    def format_response(self, results: dict, matrix_to_url: str, sender: str) -> str:
        """
        Format the response message based on the results.

        For each image with at least one detection in ``block_labels``, one
        line is produced describing its highest-scoring blocked detection (so
        the line reflects the detection most likely to make it NSFW).

        :param results: Dictionary of results with MXC URLs as keys and detection results as values.
        :param matrix_to_url: The matrix.to URL for the original message.
        :param sender: The user ID of the message sender.
        :return: The formatted response message ("" if nothing matched).
        """
        response_parts = []
        for mxc_url, detections in results.items():
            blocked = [detection for detection in detections if detection['class'] in block_labels]
            if not blocked:
                continue
            top = max(blocked, key=lambda detection: detection['score'])
            response_parts.append(
                f"{mxc_url} contains {top['class']} "
                f"with a score of {top['score']:.2f} "
                f"in {matrix_to_url} "
                f"by {sender}"
            )
        return "\n".join(response_parts)

    async def send_responses(self, evt: MessageEvent, response: str, results: dict) -> None:
        """
        Send responses or take actions based on config.

        An image is NSFW when a detection has a class in ``block_labels`` and a
        score strictly above ``min_score``. ``actions`` keys read here:
        ``ignore_sfw``, ``direct_reply`` and ``redact_nsfw``. A failed reply or
        report is logged and does not prevent redaction.

        :param evt: The message event.
        :param response: The formatted response message; nothing is sent if it is empty.
        :param results: Dictionary of results with MXC URLs as keys and detection results as values.
        """
        try:
            ignore_sfw = self.actions.get("ignore_sfw", False)
            redact_nsfw = self.actions.get("redact_nsfw", False)

            is_nsfw = False
            for mxc_url, detections in results.items():
                for detection in detections:
                    if detection['class'] not in block_labels:
                        continue
                    if detection['score'] <= min_score:
                        self.log.info(f"{mxc_url} in class {detection['class']} is SFW because score {detection['score']} is below minimum {min_score} score.")
                    else:
                        is_nsfw = True
                        self.log.info(f"{mxc_url} is NSFW because {detection['class']} is blocked (score: {detection['score']})")

            if not is_nsfw:
                self.log.info(f"{evt.room_id} is SFW")
                if ignore_sfw:
                    self.log.info(f"Ignored SFW images in {evt.room_id}")
                    return

            if not response:
                # Avoid sending empty messages to the room / report room.
                self.log.debug(f"Nothing to report for {evt.event_id} in {evt.room_id}")
            else:
                if self.actions.get("direct_reply", False):
                    try:
                        await evt.reply(response)
                    except Exception as e:
                        self.log.warning(f"Failed to reply to {evt.event_id} in {evt.room_id}: {e}")
                if self.report_to_room:
                    try:
                        await self.client.send_text(room_id=RoomID(self.report_to_room), text=response)
                    except Exception as e:
                        self.log.warning(f"Failed to send message to {RoomID(self.report_to_room)}: {e}")

            if is_nsfw and redact_nsfw:
                try:
                    await self.client.redact(room_id=evt.room_id, event_id=evt.event_id, reason="NSFW")
                    self.log.info(f"Redacted NSFW message in {evt.room_id}")
                except MForbidden:
                    self.log.warning(f"Failed to redact NSFW message in {evt.room_id}")
        except Exception as e:
            self.log.error(f"Error sending responses: {e}")

    def create_matrix_to_url(self, room_id: RoomID, event_id: EventID) -> str:
        """
        Create a matrix.to URL for a given room ID and event ID.

        :param room_id: The room ID.
        :param event_id: The event ID.
        :return: The matrix.to URL, with one ``via=`` parameter per configured server.
        """
        via_params = ("?" + "&".join(f"via={server}" for server in self.via_servers)) if self.via_servers else ""
        return f"https://matrix.to/#/{room_id}/{event_id}{via_params}"

    def extract_img_tags(self, html: str) -> List[str]:
        """
        Extract image URLs from <img> tags in the HTML content.

        :param html: The HTML content.
        :return: List of ``src`` values of <img> tags that have one.
        """
        soup = BeautifulSoup(html, "html.parser")
        return [img["src"] for img in soup.find_all("img") if "src" in img.attrs]
|