some fixes for detection
This commit is contained in:
parent
49a0b9dd89
commit
02da96150d
2 changed files with 14 additions and 8 deletions
|
@ -1,6 +1,6 @@
|
||||||
maubot: 0.1.0
|
maubot: 0.1.0
|
||||||
id: uk.tcpip.nsfwbot
|
id: uk.tcpip.nsfwbot
|
||||||
version: 0.2.4
|
version: 0.2.5
|
||||||
license: AGPL-3.0-or-later
|
license: AGPL-3.0-or-later
|
||||||
modules:
|
modules:
|
||||||
- nsfwbot
|
- nsfwbot
|
||||||
|
|
20
nsfwbot.py
20
nsfwbot.py
|
@ -17,16 +17,15 @@ from maubot.handlers import command, event
|
||||||
|
|
||||||
# Initialize NudeDetector
|
# Initialize NudeDetector
|
||||||
detector = NudeDetector()
|
detector = NudeDetector()
|
||||||
|
min_score = 0.35
|
||||||
block_labels = [
|
block_labels = [
|
||||||
"FEMALE_GENITALIA_COVERED",
|
"FEMALE_GENITALIA_COVERED",
|
||||||
"BUTTOCKS_EXPOSED",
|
"BUTTOCKS_EXPOSED",
|
||||||
"FEMALE_BREAST_EXPOSED",
|
"FEMALE_BREAST_EXPOSED",
|
||||||
"FEMALE_GENITALIA_EXPOSED",
|
"FEMALE_GENITALIA_EXPOSED",
|
||||||
"ANUS_EXPOSED",
|
|
||||||
"MALE_GENITALIA_EXPOSED",
|
"MALE_GENITALIA_EXPOSED",
|
||||||
"ANUS_COVERED",
|
"ANUS_EXPOSED",
|
||||||
"BUTTOCKS_COVERED",
|
"ANUS_COVERED"
|
||||||
]
|
]
|
||||||
|
|
||||||
class Config(BaseProxyConfig):
|
class Config(BaseProxyConfig):
|
||||||
|
@ -110,11 +109,13 @@ class NSFWModelPlugin(Plugin):
|
||||||
|
|
||||||
# Add current timestamp and event ID
|
# Add current timestamp and event ID
|
||||||
self.user_image_data[user_id].append((current_time, evt.event_id))
|
self.user_image_data[user_id].append((current_time, evt.event_id))
|
||||||
|
results = None
|
||||||
try:
|
try:
|
||||||
if not isinstance(evt.content, MediaMessageEventContent) or not evt.content.url:
|
if not isinstance(evt.content, MediaMessageEventContent) or not evt.content.url:
|
||||||
return
|
return
|
||||||
results = await self.process_images([evt.content.url])
|
results = await self.process_images([evt.content.url])
|
||||||
|
if results is None:
|
||||||
|
return
|
||||||
matrix_to_url = self.create_matrix_to_url(evt.room_id, evt.event_id)
|
matrix_to_url = self.create_matrix_to_url(evt.room_id, evt.event_id)
|
||||||
response = self.format_response(results, matrix_to_url, evt.sender)
|
response = self.format_response(results, matrix_to_url, evt.sender)
|
||||||
await self.send_responses(evt, response, results)
|
await self.send_responses(evt, response, results)
|
||||||
|
@ -195,8 +196,13 @@ class NSFWModelPlugin(Plugin):
|
||||||
for mxc_url, detections in results.items():
|
for mxc_url, detections in results.items():
|
||||||
for detection in detections:
|
for detection in detections:
|
||||||
if detection['class'] in block_labels:
|
if detection['class'] in block_labels:
|
||||||
detected = 1
|
below_min_score = False
|
||||||
self.log.info(f"{mxc_url} is NSFW because {detection['class']} is blocked")
|
if detection['score'] <= min_score:
|
||||||
|
self.log.info(f"{mxc_url} in class {detection['class']} is SFW because score {detection['score']} is below minimum {min_score} score.")
|
||||||
|
below_min_score = True
|
||||||
|
if not below_min_score:
|
||||||
|
detected = 1
|
||||||
|
self.log.info(f"{mxc_url} is NSFW because {detection['class']} is blocked (score: {detection['score']})")
|
||||||
if detected == 0:
|
if detected == 0:
|
||||||
self.log.info(f"{evt.room_id} is SFW")
|
self.log.info(f"{evt.room_id} is SFW")
|
||||||
nsfw_results = False
|
nsfw_results = False
|
||||||
|
|
Loading…
Reference in a new issue