508 lines
18 KiB
Python
508 lines
18 KiB
Python
#!/usr/bin/env python3
"""
Step 3: Use an LLM to infer speaker names from transcript context.

Input:  Line-formatted files in the "_lines/" folder
Output: Files with inferred speaker names in the "_speakers/" folder

This version uses a multi-step approach:
1. First identify Malabar (he is always present)
2. Then identify each remaining speaker one by one
3. Each step saves debug info to "_speakers_debug/"

Usage:
    uv run step3_infer_speakers.py [--force | -f]

    --force, -f   Reprocess files even if they are already marked completed

Environment Variables:
    OPENAI_API_KEY  - Required (can be an OpenAI, Kimi, or GLM key)
    OPENAI_BASE_URL - Optional (for Kimi/GLM APIs)
    LLM_MODEL       - Optional (e.g., "glm-4.5-air", "kimi-latest")
"""
|
|
|
|
import os
|
|
import re
|
|
import sys
|
|
import json
|
|
from pathlib import Path
|
|
from typing import List, Dict, Tuple, Optional
|
|
from openai import OpenAI
|
|
|
|
# ============== Configuration ==============

# Working folders and the progress-tracking file (relative paths).
INPUT_DIR = Path("_lines")
OUTPUT_DIR = Path("_speakers")
DEBUG_DIR = Path("_speakers_debug")
PROGRESS_FILE = Path(".step3_progress.json")

# Examples of good speaker names (for reference, not a restricted list)
NAME_EXAMPLES = ["Malabar", "Sun", "Jupiter", "Kangaroo", "Mole"]

# Default endpoint/model pair per provider; a base_url of None means the
# client is created without an explicit endpoint.
DEFAULT_CONFIGS = {
    "openai": dict(base_url=None, model="gpt-4o-mini"),
    "moonshot": dict(base_url="https://api.moonshot.cn/v1", model="kimi-latest"),
    # Zhipu AI (GLM)
    "bigmodel": dict(base_url="https://open.bigmodel.cn/api/paas/v4", model="glm-4.5-air"),
}
|
|
|
|
|
|
def ensure_dirs():
    """Create the output and debug folders when they do not exist yet."""
    for folder in (OUTPUT_DIR, DEBUG_DIR):
        folder.mkdir(exist_ok=True)
|
|
|
|
|
|
def load_progress() -> dict:
    """Return the stored progress record.

    Returns:
        The parsed JSON content of PROGRESS_FILE, or an empty dict when the
        file does not exist.
    """
    if not PROGRESS_FILE.exists():
        return {}
    with open(PROGRESS_FILE, 'r', encoding='utf-8') as handle:
        raw_json = handle.read()
    return json.loads(raw_json)
|
|
|
|
|
|
def save_progress(progress: dict):
    """Persist the progress record to PROGRESS_FILE as indented JSON.

    The record is serialized before any file is touched and is written to a
    sibling temporary file that is then moved over PROGRESS_FILE with
    os.replace(). A serialization error or an interrupted write therefore
    cannot leave a truncated progress file behind for load_progress().

    Args:
        progress: JSON-serializable progress record.

    Raises:
        TypeError: If ``progress`` contains values json cannot serialize.
        OSError: If the temporary file cannot be written or moved.
    """
    # Serialize first: a failure here must not clobber the existing file.
    payload = json.dumps(progress, indent=2)

    tmp_path = PROGRESS_FILE.with_name(PROGRESS_FILE.name + ".tmp")
    with open(tmp_path, 'w', encoding='utf-8') as f:
        f.write(payload)

    # Rename within the same directory so the final file is swapped in one step.
    os.replace(tmp_path, PROGRESS_FILE)
|
|
|
|
|
|
def get_llm_config() -> Tuple[Optional[str], str]:
    """Resolve the LLM endpoint and model name from environment variables.

    Reads OPENAI_API_KEY (only checked for presence), OPENAI_BASE_URL and
    LLM_MODEL. When no model is given, a default is picked from
    DEFAULT_CONFIGS based on a provider keyword found in the base URL.

    Returns:
        A ``(base_url, model)`` pair. ``base_url`` is None when
        OPENAI_BASE_URL is unset or empty.

    Raises:
        ValueError: If OPENAI_API_KEY is unset or empty.
    """
    if not os.getenv("OPENAI_API_KEY"):
        raise ValueError("OPENAI_API_KEY environment variable is required")

    base_url = os.getenv("OPENAI_BASE_URL")
    model = os.getenv("LLM_MODEL")

    if not base_url:
        return None, model or DEFAULT_CONFIGS["openai"]["model"]

    # An explicit model always wins over provider sniffing.
    if model:
        return base_url, model

    # Match case-insensitively so an upper-case host still selects the
    # right provider default.
    url_hint = base_url.lower()
    if "bigmodel" in url_hint:
        provider = "bigmodel"
    elif "moonshot" in url_hint or "kimi" in url_hint:
        provider = "moonshot"
    else:
        provider = "openai"
    return base_url, DEFAULT_CONFIGS[provider]["model"]
|
|
|
|
|
|
def parse_lines(lines_text: str) -> List[Tuple[str, str, str]]:
    """Split formatted transcript text into (timestamp, speaker_label, text) tuples.

    Each accepted line looks like ``[MM:SS](Speaker X) text`` or
    ``[MM:SS](Song) text``, where X is an upper-case letter or ``?``.
    "Song" is the label reserved for the opening song. Blank lines and lines
    that do not fit the format are dropped.

    Returns:
        Tuples in input order; the speaker label is the bare letter (or ``?``)
        for "Speaker X" lines and "Song" for song lines.
    """
    line_format = re.compile(r'^(\[\d{2}:\d{2}\])\((Speaker [A-Z?]|Song)\) (.+)$')
    parsed = []

    for raw_line in lines_text.strip().split('\n'):
        candidate = raw_line.strip()
        found = line_format.match(candidate) if candidate else None
        if found is None:
            continue

        stamp, label, spoken = found.groups()
        # "Speaker X" is reduced to "X"; "Song" is kept as-is.
        if label != "Song":
            label = label.replace("Speaker ", "")
        parsed.append((stamp, label, spoken))

    return parsed
|
|
|
|
|
|
def save_debug(filename: str, request: str, response: str, step: int, model: str = "", endpoint: str = ""):
    """Write one request/response exchange to a file in DEBUG_DIR.

    The file is named ``<filename>_step<step>.txt`` and is overwritten if it
    already exists. The "Model" and "Endpoint" header lines are only written
    when the corresponding argument is non-empty.
    """
    rule = "=" * 60 + "\n"

    sections = [rule, "DEBUG INFO:\n", rule]
    if model:
        sections.append(f"Model: {model}\n")
    if endpoint:
        sections.append(f"Endpoint: {endpoint}\n")
    sections.extend([
        "\n",
        rule, "REQUEST:\n", rule, "\n",
        request,
        "\n\n",
        rule, "RESPONSE:\n", rule, "\n",
        response,
    ])

    target = DEBUG_DIR / f"{filename}_step{step}.txt"
    with open(target, 'w', encoding='utf-8') as out:
        out.writelines(sections)
|
|
|
|
|
|
def ask_llm_for_name(prompt: str, client: OpenAI, model: str, debug_filename: str, step: int, exclude_names: list = None, base_url: str = "") -> str:
    """Send ``prompt`` to the LLM and return the speaker name it answers with.

    The reply must equal (ignoring case) one of the accepted names below,
    minus any names listed in ``exclude_names``. The exchange is written via
    save_debug(); on any failure the debug file is rewritten with the error.

    Returns:
        The accepted name, spelled as in the accepted-names list.

    Raises:
        ValueError: If the reply is not one of the accepted names.
        Exception: Any error raised by the API call is re-raised unchanged.
    """
    accepted = ['Malabar', 'Moon', 'Earth', 'Mars', 'Sun', 'Jupiter', 'Saturn', 'Venus',
                'Mercury', 'Neptune', 'Uranus', 'Pluto', 'Galaxy', 'Star', 'Kangaroo',
                'Giraffe', 'Volcano', 'Volcanoes', 'Sea', 'Ocean', 'Wave', 'Comet',
                'Asteroid', 'Meteor', 'Nebula', 'Black Hole', 'Alien', 'Robot', 'Scientist']

    # Names already assigned to other speakers are not valid answers.
    if exclude_names:
        accepted = [candidate for candidate in accepted if candidate not in exclude_names]

    endpoint_label = base_url or "OpenAI default"

    try:
        completion = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0,
            max_tokens=20,  # only a short name is expected back
            extra_body={"thinking": {"type": "disabled"}},  # ask the provider to skip "thinking"
        )

        reply = completion.choices[0].message
        answer = reply.content.strip() if reply.content else ""

        save_debug(debug_filename, prompt, f"RAW: {answer}", step, model=model, endpoint=endpoint_label)

        # Exact match first, then a case-insensitive match.
        if answer in accepted:
            return answer

        folded = answer.lower()
        canonical = next((candidate for candidate in accepted if candidate.lower() == folded), None)
        if canonical is not None:
            return canonical

        raise ValueError(f"Invalid response from LLM: expected one of {accepted}, got '{answer}'")

    except Exception as e:
        # Record the failure (this replaces the RAW debug file) and let the caller decide.
        save_debug(debug_filename, prompt, f"ERROR: {e}", step, model=model, endpoint=endpoint_label)
        raise
|
|
|
|
|
|
def identify_malabar(dialogue_lines: List[Tuple[str, str, str]],
                     client: OpenAI, model: str, debug_filename: str, base_url: str = "") -> Optional[str]:
    """Ask the LLM which speaker letter is Malabar.

    Only single-letter alphabetic speaker labels are candidates; "?", "Song"
    and any other label are left out of both the candidate list and the
    dialogue sent to the model. The exchange is written via save_debug() as
    step 1.

    Args:
        dialogue_lines: (timestamp, speaker_label, text) tuples in order.
        client: OpenAI-compatible client used for the chat completion.
        model: Model name passed to the API.
        debug_filename: Base name for the debug file.
        base_url: Endpoint shown in the debug file; when empty, the client's
            own ``base_url`` is used if it has one.

    Returns:
        The speaker letter identified as Malabar, or None when there are no
        candidate speakers (no API call is made in that case).

    Raises:
        ValueError: If the reply is not exactly one of the candidate letters.
        Exception: Any error raised by the API call is re-raised after the
            error has been written to the debug file.
    """
    speakers = sorted({speaker for _, speaker, _ in dialogue_lines
                       if len(speaker) == 1 and speaker.isalpha()})

    if not speakers:
        return None

    # All candidate lines, in their original chronological order.
    sample_text = '\n'.join(f'{spk}: "{text}"'
                            for _, spk, text in dialogue_lines
                            if spk in speakers)

    # Offer exactly the letters present in this transcript; a fixed
    # "A, B, or C" would contradict the validation below for other sets.
    if len(speakers) <= 2:
        letter_choices = " or ".join(speakers)
    else:
        letter_choices = ", ".join(speakers[:-1]) + f", or {speakers[-1]}"

    prompt = f"""Little Malabar dialogue. Malabar is a boy who talks to stars, planets and animals.

{sample_text}

Which speaker letter is Malabar? Reply with ONLY the letter {letter_choices}."""

    # Endpoint label for the debug file, shared by the success and error paths.
    endpoint = base_url or (str(client.base_url) if hasattr(client, 'base_url') else "OpenAI default")

    try:
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "user", "content": prompt}
            ],
            temperature=0.0,
            max_tokens=10,  # Short response expected
            extra_body={"thinking": {"type": "disabled"}}  # Disable thinking
        )

        message = response.choices[0].message
        raw_result = message.content.strip() if message.content else ""

        save_debug(debug_filename, prompt, f"RAW: {raw_result}", 1, model=model, endpoint=endpoint)

        # Simple validation: a single character that is one of the candidates.
        if len(raw_result) == 1 and raw_result.upper() in speakers:
            return raw_result.upper()

        raise ValueError(
            f"Invalid response from LLM: expected single letter {'/'.join(speakers)}, got '{raw_result}'"
        )

    except Exception as e:
        # Record the failure (this replaces the RAW debug file) and let the caller decide.
        save_debug(debug_filename, prompt, f"ERROR: {e}", 1, model=model, endpoint=endpoint)
        raise
|
|
|
|
|
|
def identify_speaker(speaker: str,
                     dialogue_lines: List[Tuple[str, str, str]],
                     known_names: Dict[str, str],
                     client: OpenAI, model: str, debug_filename: str, step: int) -> str:
    """Ask the LLM for the name of one speaker, given the full dialogue.

    In the dialogue sent to the model, speakers with a known name appear
    under that name, "Song" lines stay labelled "Song", and everyone else,
    including the target, appears as "Speaker X". Names already in
    ``known_names`` are passed to ask_llm_for_name() as excluded answers.

    Returns:
        The name returned by ask_llm_for_name().
    """
    def display_label(spk: str) -> str:
        # The target is never replaced by a name, even if one is known.
        if spk != speaker:
            if spk in known_names:
                return known_names[spk]
            if spk == "Song":
                return "Song"
        return f"Speaker {spk}"

    full_dialogue = '\n'.join(
        f'{display_label(spk)}: "{text}"' for _, spk, text in dialogue_lines
    )

    # Summary of the assignments made so far (empty when there are none).
    if known_names:
        assignments = [f"Speaker {s} = {n}" for s, n in known_names.items()]
        known_info = "Known: " + ", ".join(assignments)
    else:
        known_info = ""

    prompt = f"""Little Malabar dialogue. {known_info}

CONTEXT:
- Malabar is a boy who talks to stars, planets and animals
- Other speakers are usually celestial bodies (Moon, Earth, Mars, Sun, etc.)
- BUT speakers can also be other entities: volcanoes, the sea, a comet, a star, etc.
- Look at what the speaker talks about AND what others say to them to identify them

IDENTIFICATION GUIDELINES:
- Speaker mentions "my surface" + warm/shaking → likely Earth
- Speaker mentions being "up here" with no ocean → likely Moon
- Speaker says "us volcanoes on Mars" → this is Volcanoes (not Mars!)
- Speaker is spoken TO about Mars/volcanoes → could be Mars
- Speaker mentions the sea/ocean/waves → could be Sea/Ocean
- Speaker suggests going TO a place → likely describing that place from outside
- Use your judgment based on context and content

FULL DIALOGUE:
{full_dialogue}

Who is Speaker {speaker}? Reply with ONLY the name, nothing else. Examples: Moon, Earth, Mars, Volcanoes, Sea, Sun, Jupiter:"""

    taken_names = list(known_names.values()) if known_names else []
    # The client's endpoint is forwarded only for the debug file.
    client_endpoint = getattr(client, 'base_url', "")
    return ask_llm_for_name(prompt, client, model, debug_filename, step,
                            exclude_names=taken_names, base_url=client_endpoint)
|
|
|
|
|
|
def process_lines_file(input_path: Path, client: OpenAI, model: str, force: bool = False, base_url: str = "") -> Optional[Path]:
    """Infer speaker names for one lines file and write the labelled result.

    Steps: identify Malabar among the lettered speakers, then identify each
    remaining lettered speaker, then each non-letter speaker (such as "?").
    Speakers are handled in sorted order so step numbers, debug file names
    and the known-name context given to the LLM are the same on every run.
    A failed identification falls back to a placeholder name instead of
    aborting the file.

    Args:
        input_path: Lines file to process.
        client: OpenAI-compatible client used for the LLM calls.
        model: Model name passed to the API.
        force: When True, reprocess even if the progress record marks the
            file as completed.
        base_url: Endpoint shown in the Malabar-step debug file.

    Returns:
        Path of the written output file; the previously recorded output path
        when the file is skipped; or None when no valid lines were parsed.
    """
    progress = load_progress()
    filename = input_path.name

    # Skip files that are recorded as completed and whose output still exists.
    recorded = progress.get(filename)
    if not force and recorded and recorded.get("status") == "completed":
        recorded_output = recorded.get("output_file")
        if recorded_output:
            output_path = Path(recorded_output)
            if output_path.exists():
                print(f"Skipping {filename} (already processed)")
                return output_path

    print(f"\n{'='*50}")
    print(f"Processing: {input_path.name}")
    print(f"{'='*50}")

    debug_filename = input_path.stem

    with open(input_path, 'r', encoding='utf-8') as f:
        lines_text = f.read()

    lines = parse_lines(lines_text)
    print(f" Parsed {len(lines)} lines")

    if not lines:
        print(" No valid lines found!")
        return None

    # Unique speakers, excluding "Song" (already labelled). Sorted so the
    # processing order does not depend on set iteration order.
    speakers_to_identify = sorted({speaker for _, speaker, _ in lines if speaker != "Song"})

    print(f" Speakers to identify: {', '.join(speakers_to_identify)}")

    final_mapping = {}

    if not speakers_to_identify:
        print(f" No speakers to identify (only Song present)")
    else:
        # Lettered speakers versus unknown/merged speakers (like "?").
        regular_speakers = [s for s in speakers_to_identify if s.isalpha()]
        unknown_speakers = [s for s in speakers_to_identify if not s.isalpha()]

        # Step 1: Identify Malabar (from lettered speakers only).
        if regular_speakers:
            print(f" Step 1: Identifying Malabar...")
            try:
                malabar_speaker = identify_malabar(lines, client, model, debug_filename, base_url)
            except Exception as e:
                print(f" Error: {e}")
                malabar_speaker = None

            if malabar_speaker is None:
                # Fallback: assume the first lettered speaker alphabetically
                # is Malabar. Never store None as a mapping key.
                malabar_speaker = regular_speakers[0]
                print(f" Fallback: Speaker {malabar_speaker} = Malabar")
            else:
                print(f" Identified Speaker {malabar_speaker} = Malabar")
            final_mapping[malabar_speaker] = "Malabar"
        else:
            print(" Step 1: No lettered speakers; skipping Malabar identification")

        # Step 2+: Identify remaining lettered speakers one by one.
        remaining = [s for s in regular_speakers if s not in final_mapping]
        step = 2

        for speaker in remaining:
            print(f" Step {step}: Identifying Speaker {speaker}...")
            try:
                name = identify_speaker(speaker, lines, final_mapping, client, model, debug_filename, step)
                final_mapping[speaker] = name
                print(f" Identified Speaker {speaker} = {name}")
            except Exception as e:
                print(f" Error: {e}")
                final_mapping[speaker] = f"Speaker_{speaker}"
                print(f" Fallback: Speaker {speaker} = Speaker_{speaker}")
            step += 1

        # Unknown speakers (like "?") go last, once all other names are known.
        for speaker in unknown_speakers:
            print(f" Step {step}: Identifying unknown Speaker {speaker}...")
            try:
                name = identify_speaker(speaker, lines, final_mapping, client, model, debug_filename, step)
                final_mapping[speaker] = name
                print(f" Identified Speaker {speaker} = {name}")
            except Exception as e:
                print(f" Error: {e}")
                final_mapping[speaker] = "Unknown"
                print(f" Marked Speaker {speaker} = Unknown")
            step += 1

    print(f" Final mapping: {final_mapping}")

    output_text = apply_speaker_names(lines, final_mapping)

    output_filename = input_path.stem.replace("_lines", "") + "_speakers.txt"
    output_path = OUTPUT_DIR / output_filename

    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(output_text)

    progress[filename] = {
        "status": "completed",
        "output_file": str(output_path),
        "speaker_mapping": final_mapping
    }
    save_progress(progress)

    print(f" Saved to: {output_path}")

    return output_path
|
|
|
|
|
|
def apply_speaker_names(lines: List[Tuple[str, str, str]], mapping: Dict[str, str]) -> str:
    """Render the lines as ``[timestamp](Name) text``, one per output line.

    "Song" lines always keep the label "Song", regardless of ``mapping``.
    Any other speaker missing from ``mapping`` is rendered as ``Speaker_<label>``.

    Returns:
        The rendered lines joined with newlines (no trailing newline).
    """
    def resolve(label: str) -> str:
        if label == "Song":
            return "Song"
        return mapping.get(label, f"Speaker_{label}")

    return '\n'.join(
        f"{stamp}({resolve(label)}) {spoken}" for stamp, label, spoken in lines
    )
|
|
|
|
|
|
def get_input_files() -> list[Path]:
    """List the ``.txt`` files directly inside INPUT_DIR, sorted by path.

    Returns:
        Sorted file paths, or an empty list when INPUT_DIR does not exist.
    """
    if not INPUT_DIR.exists():
        return []
    return sorted(
        entry for entry in INPUT_DIR.iterdir()
        if entry.is_file() and entry.suffix == '.txt'
    )
|
|
|
|
|
|
def main():
    """Process every lines file in INPUT_DIR and print a summary.

    Command-line flags ``--force`` / ``-f`` reprocess files that are already
    marked completed. A file that raises is recorded with status "error" in
    the progress file and counted as failed. The process exits with status 1
    when no input files are found or when any file failed.
    """
    import traceback

    ensure_dirs()

    # Force flag: either spelling may appear anywhere on the command line.
    force = any(flag in sys.argv for flag in ("--force", "-f"))

    # LLM configuration and client.
    base_url, model = get_llm_config()
    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=base_url)

    print(f"Using model: {model}")
    print(f"Endpoint: {base_url or 'OpenAI default'}")

    lines_files = get_input_files()
    if not lines_files:
        print(f"No .txt files found in {INPUT_DIR}/")
        sys.exit(1)

    print(f"Found {len(lines_files)} transcript(s) in {INPUT_DIR}/")
    if force:
        print("Force mode: ON (reprocessing all files)")
    print(f"Debug info will be saved to {DEBUG_DIR}/")
    print("")

    succeeded = 0
    failed = 0

    for input_path in lines_files:
        try:
            result_path = process_lines_file(input_path, client, model, force=force, base_url=base_url or "")
        except Exception as e:
            # Record the failure so it is visible in the progress file, then move on.
            progress = load_progress()
            progress[input_path.name] = {"status": "error", "error": str(e)}
            save_progress(progress)
            print(f"\n❌ Failed to process {input_path.name}: {e}")
            traceback.print_exc()
            failed += 1
        else:
            # A falsy result (no valid lines) counts as neither success nor failure.
            if result_path:
                succeeded += 1

    banner = "=" * 50
    print("\n" + banner)
    print(f"Step 3 Complete: {succeeded} succeeded, {failed} failed")
    print(f"Debug files saved to: {DEBUG_DIR}/")
    print(banner)

    if failed > 0:
        sys.exit(1)


if __name__ == "__main__":
    main()
|