a
This commit is contained in:
492
step3_infer_speakers.py
Normal file
492
step3_infer_speakers.py
Normal file
@@ -0,0 +1,492 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Step 3: Use LLM to infer speaker names from transcript context.
|
||||
|
||||
Input: Line-formatted files in "_lines/" folder
|
||||
Output: Files with inferred speaker names in "_speakers/" folder
|
||||
|
||||
This version uses a multi-step approach:
|
||||
1. First identify Malabar (he's always present)
|
||||
2. Then identify each remaining speaker one by one
|
||||
3. Each step saves debug info to _speakers_debug/
|
||||
|
||||
Usage:
|
||||
uv run step3_infer_speakers.py
|
||||
|
||||
Environment Variables:
|
||||
OPENAI_API_KEY - Required (can be OpenAI, Kimi, or GLM key)
|
||||
OPENAI_BASE_URL - Optional (for Kimi/GLM APIs)
|
||||
LLM_MODEL - Optional (e.g., "glm-4.5-air", "kimi-latest")
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
from openai import OpenAI
|
||||
|
||||
# ============== Configuration ==============
|
||||
|
||||
INPUT_DIR = Path("_lines")
|
||||
OUTPUT_DIR = Path("_speakers")
|
||||
DEBUG_DIR = Path("_speakers_debug")
|
||||
|
||||
# Examples of good speaker names (for reference, not a restricted list)
|
||||
NAME_EXAMPLES = ["Malabar", "Sun", "Jupiter", "Kangaroo", "Mole"]
|
||||
|
||||
# Default configurations for different providers
|
||||
DEFAULT_CONFIGS = {
|
||||
"openai": {
|
||||
"base_url": None,
|
||||
"model": "gpt-4o-mini"
|
||||
},
|
||||
"moonshot": {
|
||||
"base_url": "https://api.moonshot.cn/v1",
|
||||
"model": "kimi-latest"
|
||||
},
|
||||
"bigmodel": { # Zhipu AI (GLM)
|
||||
"base_url": "https://open.bigmodel.cn/api/paas/v4",
|
||||
"model": "glm-4.5-air"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def ensure_dirs():
|
||||
"""Ensure output directories exist."""
|
||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||
DEBUG_DIR.mkdir(exist_ok=True)
|
||||
|
||||
|
||||
def get_llm_config() -> Tuple[str, str]:
|
||||
"""Get LLM configuration from environment."""
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError("OPENAI_API_KEY environment variable is required")
|
||||
|
||||
base_url = os.getenv("OPENAI_BASE_URL")
|
||||
model = os.getenv("LLM_MODEL")
|
||||
|
||||
if base_url:
|
||||
if model:
|
||||
return base_url, model
|
||||
if "bigmodel" in base_url:
|
||||
return base_url, DEFAULT_CONFIGS["bigmodel"]["model"]
|
||||
elif "moonshot" in base_url or "kimi" in base_url:
|
||||
return base_url, DEFAULT_CONFIGS["moonshot"]["model"]
|
||||
else:
|
||||
return base_url, DEFAULT_CONFIGS["openai"]["model"]
|
||||
else:
|
||||
return None, model or DEFAULT_CONFIGS["openai"]["model"]
|
||||
|
||||
|
||||
def parse_lines(lines_text: str) -> List[Tuple[str, str, str]]:
|
||||
"""Parse formatted lines. Returns list of (timestamp, speaker_label, text)."""
|
||||
pattern = r'^(\[\d{2}:\d{2}\])\(Speaker ([A-Z?])\) (.+)$'
|
||||
result = []
|
||||
|
||||
for line in lines_text.strip().split('\n'):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
match = re.match(pattern, line)
|
||||
if match:
|
||||
timestamp = match.group(1)
|
||||
speaker = match.group(2)
|
||||
text = match.group(3)
|
||||
result.append((timestamp, speaker, text))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def parse_timestamp(ts: str) -> int:
|
||||
"""Parse [mm:ss] timestamp to total seconds."""
|
||||
match = re.match(r'\[(\d{2}):(\d{2})\]', ts)
|
||||
if match:
|
||||
minutes = int(match.group(1))
|
||||
seconds = int(match.group(2))
|
||||
return minutes * 60 + seconds
|
||||
return 0
|
||||
|
||||
|
||||
def classify_speakers_by_time(lines: List[Tuple[str, str, str]]) -> Tuple[set, set]:
|
||||
"""Classify speakers based on when they appear."""
|
||||
all_speakers = set(speaker for _, speaker, _ in lines)
|
||||
|
||||
song_speakers = set()
|
||||
dialogue_speakers = set()
|
||||
|
||||
for speaker in all_speakers:
|
||||
has_lines_after_15 = any(
|
||||
parse_timestamp(ts) > 15 and spk == speaker
|
||||
for ts, spk, _ in lines
|
||||
)
|
||||
if has_lines_after_15:
|
||||
dialogue_speakers.add(speaker)
|
||||
else:
|
||||
has_lines_in_first_15 = any(
|
||||
parse_timestamp(ts) <= 15 and spk == speaker
|
||||
for ts, spk, _ in lines
|
||||
)
|
||||
if has_lines_in_first_15:
|
||||
song_speakers.add(speaker)
|
||||
|
||||
return song_speakers, dialogue_speakers
|
||||
|
||||
|
||||
def format_dialogue_with_names(lines: List[Tuple[str, str, str]], speaker_names: Dict[str, str]) -> str:
|
||||
"""Format dialogue lines with known speaker names."""
|
||||
result_lines = []
|
||||
for timestamp, speaker, text in lines:
|
||||
# Skip lines in first 15 seconds (opening song)
|
||||
if parse_timestamp(timestamp) <= 15:
|
||||
continue
|
||||
name = speaker_names.get(speaker, f"Speaker_{speaker}")
|
||||
result_lines.append(f'{timestamp}({name}) {text}')
|
||||
return '\n'.join(result_lines)
|
||||
|
||||
|
||||
def save_debug(filename: str, request: str, response: str, step: int):
|
||||
"""Save debug info to _speakers_debug folder."""
|
||||
debug_file = DEBUG_DIR / f"{filename}_step{step}.txt"
|
||||
with open(debug_file, 'w', encoding='utf-8') as f:
|
||||
f.write("=" * 60 + "\n")
|
||||
f.write("REQUEST:\n")
|
||||
f.write("=" * 60 + "\n\n")
|
||||
f.write(request)
|
||||
f.write("\n\n")
|
||||
f.write("=" * 60 + "\n")
|
||||
f.write("RESPONSE:\n")
|
||||
f.write("=" * 60 + "\n\n")
|
||||
f.write(response)
|
||||
|
||||
|
||||
def extract_name_from_response(text: str) -> str:
|
||||
"""Extract a single name from LLM response text."""
|
||||
text = text.strip()
|
||||
|
||||
# Expanded list of valid names - includes celestial bodies and other entities
|
||||
valid_names = ['Malabar', 'Moon', 'Earth', 'Mars', 'Sun', 'Jupiter', 'Saturn', 'Venus',
|
||||
'Mercury', 'Neptune', 'Uranus', 'Pluto', 'Galaxy', 'Star', 'Kangaroo',
|
||||
'Giraffe', 'Volcano', 'Volcanoes', 'Sea', 'Ocean', 'Wave', 'Comet',
|
||||
'Asteroid', 'Meteor', 'Nebula', 'Black Hole', 'Alien', 'Robot', 'Scientist']
|
||||
|
||||
# Check if the response is just a single word (the name)
|
||||
if ' ' not in text and len(text) > 1:
|
||||
return text.strip('"\'')
|
||||
|
||||
# Look for explicit "Answer: X" or "Name: X" patterns
|
||||
answer_match = re.search(r'(?:answer|name|is)[:\s]+["\']?([A-Z][a-z]+)', text, re.IGNORECASE)
|
||||
if answer_match:
|
||||
return answer_match.group(1)
|
||||
|
||||
# Check last few lines for a valid name
|
||||
lines = text.split('\n')
|
||||
for line in reversed(lines[-5:]): # Check last 5 lines
|
||||
line = line.strip().strip('"\'')
|
||||
for name in valid_names:
|
||||
if line.lower() == name.lower():
|
||||
return name
|
||||
if re.search(rf'\b{name}\b', line, re.IGNORECASE):
|
||||
return name
|
||||
|
||||
# Default: return first valid name found
|
||||
for name in valid_names:
|
||||
if re.search(rf'\b{name}\b', text, re.IGNORECASE):
|
||||
return name
|
||||
|
||||
# If no known name found, extract any capitalized word as potential name
|
||||
for line in text.split('\n'):
|
||||
line = line.strip()
|
||||
match = re.search(r'\b([A-Z][a-z]{2,})\b', line)
|
||||
if match:
|
||||
word = match.group(1)
|
||||
if word.lower() not in ['the', 'and', 'but', 'for', 'are', 'was', 'were', 'been', 'this', 'that']:
|
||||
return word
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def ask_llm_for_name(prompt: str, client: OpenAI, model: str, debug_filename: str, step: int) -> str:
|
||||
"""Ask LLM for a single name. Returns the name or empty string if failed."""
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{"role": "system", "content": "Reply with ONLY a single word - the name. No explanation."},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
temperature=0.0,
|
||||
max_tokens=1000
|
||||
)
|
||||
|
||||
message = response.choices[0].message
|
||||
raw_result = message.content or ""
|
||||
|
||||
# If content is empty but reasoning_content exists, use that
|
||||
if not raw_result and hasattr(message, 'reasoning_content') and message.reasoning_content:
|
||||
raw_result = message.reasoning_content
|
||||
|
||||
# Extract name from the response
|
||||
result = extract_name_from_response(raw_result)
|
||||
|
||||
# Save debug info
|
||||
save_debug(debug_filename, prompt, f"RAW: {raw_result[:800]}\n\nEXTRACTED: {result}", step)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
save_debug(debug_filename, prompt, f"ERROR: {e}", step)
|
||||
return ""
|
||||
|
||||
|
||||
def identify_malabar(dialogue_lines: List[Tuple[str, str, str]],
|
||||
client: OpenAI, model: str, debug_filename: str) -> Optional[str]:
|
||||
"""Identify which speaker is Malabar."""
|
||||
# Only consider single-letter speakers (exclude "?" and other special markers)
|
||||
speakers = sorted(set(speaker for _, speaker, _ in dialogue_lines
|
||||
if parse_timestamp(_) > 15 and len(speaker) == 1 and speaker.isalpha()))
|
||||
|
||||
if not speakers:
|
||||
return None
|
||||
|
||||
# Get sample lines from each speaker
|
||||
samples = []
|
||||
for speaker in speakers:
|
||||
lines = [(ts, text) for ts, spk, text in dialogue_lines
|
||||
if spk == speaker and parse_timestamp(ts) > 15][:3]
|
||||
for ts, text in lines:
|
||||
samples.append(f'{speaker}: "{text}"')
|
||||
|
||||
sample_text = '\n'.join(samples)
|
||||
|
||||
prompt = f"""Little Malabar dialogue. Malabar is the boy who addresses Kangaroo/Giraffe.
|
||||
|
||||
{sample_text}
|
||||
|
||||
Which speaker letter is Malabar? Reply with ONLY A, B, or C:"""
|
||||
|
||||
result = ask_llm_for_name(prompt, client, model, debug_filename, 1)
|
||||
|
||||
# Extract the letter
|
||||
match = re.search(r'\b([A-Z])\b', result.upper())
|
||||
if match and match.group(1) in speakers:
|
||||
return match.group(1)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def identify_speaker(speaker: str,
|
||||
dialogue_lines: List[Tuple[str, str, str]],
|
||||
known_names: Dict[str, str],
|
||||
client: OpenAI, model: str, debug_filename: str, step: int) -> str:
|
||||
"""Identify a single speaker's name."""
|
||||
# Get this speaker's lines (after 15s)
|
||||
speaker_lines = [(ts, text) for ts, spk, text in dialogue_lines
|
||||
if spk == speaker and parse_timestamp(ts) > 15]
|
||||
|
||||
# Prioritize lines with identifying keywords - Mars mentions first
|
||||
mars_lines = [l for l in speaker_lines if 'mars' in l[1].lower()]
|
||||
other_priority = [l for l in speaker_lines if 'mars' not in l[1].lower() and
|
||||
any(k in l[1].lower() for k in ['surface', 'volcanoes', 'craters', 'my surface', 'up here', 'labyrinth'])]
|
||||
other_lines = [l for l in speaker_lines if l not in mars_lines and l not in other_priority]
|
||||
|
||||
# Combine: Mars lines first, then other priority, then others, max 8 lines
|
||||
selected_lines = (mars_lines + other_priority + other_lines)[:8]
|
||||
speaker_sample = '\n'.join([f'{ts} "{text}"' for ts, text in selected_lines])
|
||||
|
||||
# Build list of who we already know
|
||||
known_info = "Known: " + ", ".join([f"Speaker {s} = {n}" for s, n in known_names.items()]) if known_names else ""
|
||||
|
||||
prompt = f"""Little Malabar dialogue. {known_info}
|
||||
|
||||
CONTEXT:
|
||||
- Malabar is the main character (a boy) who explores space
|
||||
- Other speakers are usually celestial bodies (Moon, Earth, Mars, Sun, etc.)
|
||||
- BUT speakers can also be other entities: volcanoes, the sea, a comet, a star, etc.
|
||||
- Look at what the speaker talks about to identify them
|
||||
|
||||
IDENTIFICATION GUIDELINES:
|
||||
- Speaker mentions "my surface" + warm/shaking → likely Earth
|
||||
- Speaker mentions being "up here" with no ocean → likely Moon
|
||||
- Speaker mentions "us volcanoes on Mars" → could be Mars OR Volcanoes
|
||||
- Speaker mentions the sea/ocean/waves → could be Sea/Ocean
|
||||
- Speaker suggests going TO a place → likely describing that place from outside
|
||||
- Use your judgment based on context and content
|
||||
|
||||
Speaker {speaker}'s lines:
|
||||
{speaker_sample}
|
||||
|
||||
Who is Speaker {speaker}? Reply with a single descriptive name (e.g., "Moon", "Earth", "Mars", "Volcanoes", "Sea", "Sun", "Comet", "Star"):"""
|
||||
|
||||
return ask_llm_for_name(prompt, client, model, debug_filename, step)
|
||||
|
||||
|
||||
def process_lines_file(input_path: Path, client: OpenAI, model: str) -> Path:
|
||||
"""Process a single lines file using multi-step approach."""
|
||||
print(f"\n{'='*50}")
|
||||
print(f"Processing: {input_path.name}")
|
||||
print(f"{'='*50}")
|
||||
|
||||
debug_filename = input_path.stem
|
||||
|
||||
# Read lines file
|
||||
with open(input_path, 'r', encoding='utf-8') as f:
|
||||
lines_text = f.read()
|
||||
|
||||
# Parse lines
|
||||
lines = parse_lines(lines_text)
|
||||
print(f" Parsed {len(lines)} lines")
|
||||
|
||||
if not lines:
|
||||
print(" No valid lines found!")
|
||||
return None
|
||||
|
||||
# Classify speakers
|
||||
song_speakers, dialogue_speakers = classify_speakers_by_time(lines)
|
||||
print(f" Dialogue speakers: {', '.join(sorted(dialogue_speakers))}")
|
||||
|
||||
# Build mapping starting with song speakers
|
||||
final_mapping = {}
|
||||
for speaker in song_speakers:
|
||||
final_mapping[speaker] = "Song"
|
||||
|
||||
if not dialogue_speakers:
|
||||
print(f" All lines are within first 15 seconds (opening song)")
|
||||
else:
|
||||
# Separate regular speakers from unknown/merged speakers (like "?")
|
||||
regular_speakers = [s for s in dialogue_speakers if s.isalpha()]
|
||||
unknown_speakers = [s for s in dialogue_speakers if not s.isalpha()]
|
||||
|
||||
# Step 1: Identify Malabar (from regular speakers only)
|
||||
print(f" Step 1: Identifying Malabar...")
|
||||
malabar_speaker = identify_malabar(lines, client, model, debug_filename)
|
||||
|
||||
if malabar_speaker:
|
||||
final_mapping[malabar_speaker] = "Malabar"
|
||||
print(f" Identified Speaker {malabar_speaker} = Malabar")
|
||||
elif regular_speakers:
|
||||
# Fallback: assume first regular speaker alphabetically is Malabar
|
||||
malabar_speaker = sorted(regular_speakers)[0]
|
||||
final_mapping[malabar_speaker] = "Malabar"
|
||||
print(f" Fallback: Speaker {malabar_speaker} = Malabar")
|
||||
|
||||
# Step 2+: Identify remaining regular speakers one by one
|
||||
remaining = [s for s in regular_speakers if s not in final_mapping]
|
||||
step = 2
|
||||
|
||||
for speaker in remaining:
|
||||
print(f" Step {step}: Identifying Speaker {speaker}...")
|
||||
name = identify_speaker(speaker, lines, final_mapping, client, model, debug_filename, step)
|
||||
|
||||
if name and len(name) > 1:
|
||||
final_mapping[speaker] = name
|
||||
print(f" Identified Speaker {speaker} = {name}")
|
||||
else:
|
||||
final_mapping[speaker] = f"Speaker_{speaker}"
|
||||
print(f" Fallback: Speaker {speaker} = Speaker_{speaker}")
|
||||
|
||||
step += 1
|
||||
|
||||
# Handle unknown speakers (like "?")
|
||||
for speaker in unknown_speakers:
|
||||
print(f" Step {step}: Identifying unknown Speaker {speaker}...")
|
||||
# Try to identify based on content
|
||||
name = identify_speaker(speaker, lines, final_mapping, client, model, debug_filename, step)
|
||||
if name and len(name) > 1 and name.lower() not in ['unknown', 'speaker', 'name']:
|
||||
final_mapping[speaker] = name
|
||||
print(f" Identified Speaker {speaker} = {name}")
|
||||
else:
|
||||
final_mapping[speaker] = "Unknown"
|
||||
print(f" Marked Speaker {speaker} = Unknown")
|
||||
step += 1
|
||||
|
||||
print(f" Final mapping: {final_mapping}")
|
||||
|
||||
# Apply speaker names to output
|
||||
output_text = apply_speaker_names(lines, final_mapping)
|
||||
|
||||
# Save output
|
||||
output_filename = input_path.stem.replace("_lines", "") + "_speakers.txt"
|
||||
output_path = OUTPUT_DIR / output_filename
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
f.write(output_text)
|
||||
|
||||
print(f" Saved to: {output_path}")
|
||||
|
||||
return output_path
|
||||
|
||||
|
||||
def apply_speaker_names(lines: List[Tuple[str, str, str]], mapping: Dict[str, str]) -> str:
|
||||
"""Apply speaker names to lines.
|
||||
|
||||
SPECIAL: Lines in first 15 seconds are labeled as "Song" (opening theme).
|
||||
"""
|
||||
result_lines = []
|
||||
|
||||
for timestamp, speaker, text in lines:
|
||||
# Check if this line is in the first 15 seconds
|
||||
if parse_timestamp(timestamp) <= 15:
|
||||
speaker_name = "Song"
|
||||
else:
|
||||
speaker_name = mapping.get(speaker, f"Speaker_{speaker}")
|
||||
result_lines.append(f"{timestamp}({speaker_name}) {text}")
|
||||
|
||||
return '\n'.join(result_lines)
|
||||
|
||||
|
||||
def get_input_files() -> list[Path]:
|
||||
"""Discover all text files in _lines/ folder."""
|
||||
if not INPUT_DIR.exists():
|
||||
return []
|
||||
files = [f for f in INPUT_DIR.iterdir() if f.is_file() and f.suffix == '.txt']
|
||||
return sorted(files)
|
||||
|
||||
|
||||
def main():
|
||||
ensure_dirs()
|
||||
|
||||
# Get LLM config
|
||||
base_url, model = get_llm_config()
|
||||
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=base_url)
|
||||
|
||||
# Discover input files
|
||||
lines_files = get_input_files()
|
||||
|
||||
if not lines_files:
|
||||
print(f"No .txt files found in {INPUT_DIR}/")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Found {len(lines_files)} transcript(s) in {INPUT_DIR}/")
|
||||
print(f"Debug info will be saved to {DEBUG_DIR}/")
|
||||
print("")
|
||||
|
||||
# Process all files
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
|
||||
for input_path in lines_files:
|
||||
try:
|
||||
output_path = process_lines_file(input_path, client, model)
|
||||
if output_path:
|
||||
success_count += 1
|
||||
except Exception as e:
|
||||
print(f"\n❌ Failed to process {input_path.name}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
fail_count += 1
|
||||
|
||||
# Summary
|
||||
print("\n" + "="*50)
|
||||
print(f"Step 3 Complete: {success_count} succeeded, {fail_count} failed")
|
||||
print(f"Debug files saved to: {DEBUG_DIR}/")
|
||||
print("="*50)
|
||||
|
||||
if fail_count > 0:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user