From 02da96150d2c22e248d9b1b8e1f27cbbc0583978 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nils=20B=C3=BCchner?= Date: Sat, 2 Nov 2024 16:34:20 +0100 Subject: [PATCH] some fixes for detection --- maubot.yaml | 2 +- nsfwbot.py | 20 +++++++++++++------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/maubot.yaml b/maubot.yaml index af9494f..7d750dd 100644 --- a/maubot.yaml +++ b/maubot.yaml @@ -1,6 +1,6 @@ maubot: 0.1.0 id: uk.tcpip.nsfwbot -version: 0.2.4 +version: 0.2.5 license: AGPL-3.0-or-later modules: - nsfwbot diff --git a/nsfwbot.py b/nsfwbot.py index a2d172f..942c185 100644 --- a/nsfwbot.py +++ b/nsfwbot.py @@ -17,16 +17,15 @@ from maubot.handlers import command, event # Initialize NudeDetector detector = NudeDetector() - +min_score = 0.35 block_labels = [ "FEMALE_GENITALIA_COVERED", "BUTTOCKS_EXPOSED", "FEMALE_BREAST_EXPOSED", "FEMALE_GENITALIA_EXPOSED", - "ANUS_EXPOSED", "MALE_GENITALIA_EXPOSED", - "ANUS_COVERED", - "BUTTOCKS_COVERED", + "ANUS_EXPOSED", + "ANUS_COVERED" ] class Config(BaseProxyConfig): @@ -110,11 +109,13 @@ class NSFWModelPlugin(Plugin): # Add current timestamp and event ID self.user_image_data[user_id].append((current_time, evt.event_id)) - + results = None try: if not isinstance(evt.content, MediaMessageEventContent) or not evt.content.url: return results = await self.process_images([evt.content.url]) + if results is None: + return matrix_to_url = self.create_matrix_to_url(evt.room_id, evt.event_id) response = self.format_response(results, matrix_to_url, evt.sender) await self.send_responses(evt, response, results) @@ -195,8 +196,13 @@ class NSFWModelPlugin(Plugin): for mxc_url, detections in results.items(): for detection in detections: if detection['class'] in block_labels: - detected = 1 - self.log.info(f"{mxc_url} is NSFW because {detection['class']} is blocked") + below_min_score = False + if detection['score'] <= min_score: + self.log.info(f"{mxc_url} in class {detection['class']} is SFW because score {detection['score']} is below minimum {min_score} score.") + below_min_score = True + if not below_min_score: + detected = 1 + self.log.info(f"{mxc_url} is NSFW because {detection['class']} is blocked (score: {detection['score']})") if detected == 0: self.log.info(f"{evt.room_id} is SFW") nsfw_results = False