some fixes for detection

This commit is contained in:
Nils Büchner 2024-11-02 16:34:20 +01:00
parent 49a0b9dd89
commit 02da96150d
2 changed files with 14 additions and 8 deletions

View file

@ -1,6 +1,6 @@
maubot: 0.1.0 maubot: 0.1.0
id: uk.tcpip.nsfwbot id: uk.tcpip.nsfwbot
version: 0.2.4 version: 0.2.5
license: AGPL-3.0-or-later license: AGPL-3.0-or-later
modules: modules:
- nsfwbot - nsfwbot

View file

@ -17,16 +17,15 @@ from maubot.handlers import command, event
# Initialize NudeDetector # Initialize NudeDetector
detector = NudeDetector() detector = NudeDetector()
min_score = 0.35
block_labels = [ block_labels = [
"FEMALE_GENITALIA_COVERED", "FEMALE_GENITALIA_COVERED",
"BUTTOCKS_EXPOSED", "BUTTOCKS_EXPOSED",
"FEMALE_BREAST_EXPOSED", "FEMALE_BREAST_EXPOSED",
"FEMALE_GENITALIA_EXPOSED", "FEMALE_GENITALIA_EXPOSED",
"ANUS_EXPOSED",
"MALE_GENITALIA_EXPOSED", "MALE_GENITALIA_EXPOSED",
"ANUS_COVERED", "ANUS_EXPOSED",
"BUTTOCKS_COVERED", "ANUS_COVERED"
] ]
class Config(BaseProxyConfig): class Config(BaseProxyConfig):
@ -110,11 +109,13 @@ class NSFWModelPlugin(Plugin):
# Add current timestamp and event ID # Add current timestamp and event ID
self.user_image_data[user_id].append((current_time, evt.event_id)) self.user_image_data[user_id].append((current_time, evt.event_id))
results = None
try: try:
if not isinstance(evt.content, MediaMessageEventContent) or not evt.content.url: if not isinstance(evt.content, MediaMessageEventContent) or not evt.content.url:
return return
results = await self.process_images([evt.content.url]) results = await self.process_images([evt.content.url])
if results is None:
return
matrix_to_url = self.create_matrix_to_url(evt.room_id, evt.event_id) matrix_to_url = self.create_matrix_to_url(evt.room_id, evt.event_id)
response = self.format_response(results, matrix_to_url, evt.sender) response = self.format_response(results, matrix_to_url, evt.sender)
await self.send_responses(evt, response, results) await self.send_responses(evt, response, results)
@ -195,8 +196,13 @@ class NSFWModelPlugin(Plugin):
for mxc_url, detections in results.items(): for mxc_url, detections in results.items():
for detection in detections: for detection in detections:
if detection['class'] in block_labels: if detection['class'] in block_labels:
detected = 1 below_min_score = False
self.log.info(f"{mxc_url} is NSFW because {detection['class']} is blocked") if detection['score'] <= min_score:
self.log.info(f"{mxc_url} in class {detection['class']} is SFW because score {detection['score']} is below minimum {min_score} score.")
below_min_score = True
if not below_min_score:
detected = 1
self.log.info(f"{mxc_url} is NSFW because {detection['class']} is blocked (score: {detection['score']})")
if detected == 0: if detected == 0:
self.log.info(f"{evt.room_id} is SFW") self.log.info(f"{evt.room_id} is SFW")
nsfw_results = False nsfw_results = False