Simplify the pipeline, merge the opening song
This commit is contained in:
@@ -32,6 +32,7 @@ from openai import OpenAI
|
||||
INPUT_DIR = Path("_lines")
|
||||
OUTPUT_DIR = Path("_speakers")
|
||||
DEBUG_DIR = Path("_speakers_debug")
|
||||
PROGRESS_FILE = Path(".step3_progress.json")
|
||||
|
||||
# Examples of good speaker names (for reference, not a restricted list)
|
||||
NAME_EXAMPLES = ["Malabar", "Sun", "Jupiter", "Kangaroo", "Mole"]
|
||||
@@ -59,6 +60,20 @@ def ensure_dirs():
|
||||
DEBUG_DIR.mkdir(exist_ok=True)
|
||||
|
||||
|
||||
def load_progress() -> dict:
|
||||
"""Load progress tracking."""
|
||||
if PROGRESS_FILE.exists():
|
||||
with open(PROGRESS_FILE, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
return {}
|
||||
|
||||
|
||||
def save_progress(progress: dict):
|
||||
"""Save progress tracking."""
|
||||
with open(PROGRESS_FILE, 'w', encoding='utf-8') as f:
|
||||
json.dump(progress, f, indent=2)
|
||||
|
||||
|
||||
def get_llm_config() -> Tuple[str, str]:
|
||||
"""Get LLM configuration from environment."""
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
@@ -83,7 +98,9 @@ def get_llm_config() -> Tuple[str, str]:
|
||||
|
||||
def parse_lines(lines_text: str) -> List[Tuple[str, str, str]]:
|
||||
"""Parse formatted lines. Returns list of (timestamp, speaker_label, text)."""
|
||||
pattern = r'^(\[\d{2}:\d{2}\])\(Speaker ([A-Z?])\) (.+)$'
|
||||
# Pattern to match both (Speaker X) and (Song) formats
|
||||
# Speaker "Song" is reserved for the opening song
|
||||
pattern = r'^(\[\d{2}:\d{2}\])\((Speaker [A-Z?]|Song)\) (.+)$'
|
||||
result = []
|
||||
|
||||
for line in lines_text.strip().split('\n'):
|
||||
@@ -94,60 +111,19 @@ def parse_lines(lines_text: str) -> List[Tuple[str, str, str]]:
|
||||
match = re.match(pattern, line)
|
||||
if match:
|
||||
timestamp = match.group(1)
|
||||
speaker = match.group(2)
|
||||
speaker_raw = match.group(2)
|
||||
text = match.group(3)
|
||||
# Normalize: "Speaker X" -> "X", "Song" -> "Song"
|
||||
if speaker_raw == "Song":
|
||||
speaker = "Song"
|
||||
else:
|
||||
# Extract letter from "Speaker X"
|
||||
speaker = speaker_raw.replace("Speaker ", "")
|
||||
result.append((timestamp, speaker, text))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def parse_timestamp(ts: str) -> int:
|
||||
"""Parse [mm:ss] timestamp to total seconds."""
|
||||
match = re.match(r'\[(\d{2}):(\d{2})\]', ts)
|
||||
if match:
|
||||
minutes = int(match.group(1))
|
||||
seconds = int(match.group(2))
|
||||
return minutes * 60 + seconds
|
||||
return 0
|
||||
|
||||
|
||||
def classify_speakers_by_time(lines: List[Tuple[str, str, str]]) -> Tuple[set, set]:
|
||||
"""Classify speakers based on when they appear."""
|
||||
all_speakers = set(speaker for _, speaker, _ in lines)
|
||||
|
||||
song_speakers = set()
|
||||
dialogue_speakers = set()
|
||||
|
||||
for speaker in all_speakers:
|
||||
has_lines_after_15 = any(
|
||||
parse_timestamp(ts) > 15 and spk == speaker
|
||||
for ts, spk, _ in lines
|
||||
)
|
||||
if has_lines_after_15:
|
||||
dialogue_speakers.add(speaker)
|
||||
else:
|
||||
has_lines_in_first_15 = any(
|
||||
parse_timestamp(ts) <= 15 and spk == speaker
|
||||
for ts, spk, _ in lines
|
||||
)
|
||||
if has_lines_in_first_15:
|
||||
song_speakers.add(speaker)
|
||||
|
||||
return song_speakers, dialogue_speakers
|
||||
|
||||
|
||||
def format_dialogue_with_names(lines: List[Tuple[str, str, str]], speaker_names: Dict[str, str]) -> str:
|
||||
"""Format dialogue lines with known speaker names."""
|
||||
result_lines = []
|
||||
for timestamp, speaker, text in lines:
|
||||
# Skip lines in first 15 seconds (opening song)
|
||||
if parse_timestamp(timestamp) <= 15:
|
||||
continue
|
||||
name = speaker_names.get(speaker, f"Speaker_{speaker}")
|
||||
result_lines.append(f'{timestamp}({name}) {text}')
|
||||
return '\n'.join(result_lines)
|
||||
|
||||
|
||||
def save_debug(filename: str, request: str, response: str, step: int):
|
||||
"""Save debug info to _speakers_debug folder."""
|
||||
debug_file = DEBUG_DIR / f"{filename}_step{step}.txt"
|
||||
@@ -244,9 +220,9 @@ def ask_llm_for_name(prompt: str, client: OpenAI, model: str, debug_filename: st
|
||||
def identify_malabar(dialogue_lines: List[Tuple[str, str, str]],
|
||||
client: OpenAI, model: str, debug_filename: str) -> Optional[str]:
|
||||
"""Identify which speaker is Malabar."""
|
||||
# Only consider single-letter speakers (exclude "?" and other special markers)
|
||||
# Only consider single-letter speakers (exclude "?", "Song", and other special markers)
|
||||
speakers = sorted(set(speaker for _, speaker, _ in dialogue_lines
|
||||
if parse_timestamp(_) > 15 and len(speaker) == 1 and speaker.isalpha()))
|
||||
if len(speaker) == 1 and speaker.isalpha()))
|
||||
|
||||
if not speakers:
|
||||
return None
|
||||
@@ -255,7 +231,7 @@ def identify_malabar(dialogue_lines: List[Tuple[str, str, str]],
|
||||
samples = []
|
||||
for speaker in speakers:
|
||||
lines = [(ts, text) for ts, spk, text in dialogue_lines
|
||||
if spk == speaker and parse_timestamp(ts) > 15][:3]
|
||||
if spk == speaker][:3]
|
||||
for ts, text in lines:
|
||||
samples.append(f'{speaker}: "{text}"')
|
||||
|
||||
@@ -282,9 +258,9 @@ def identify_speaker(speaker: str,
|
||||
known_names: Dict[str, str],
|
||||
client: OpenAI, model: str, debug_filename: str, step: int) -> str:
|
||||
"""Identify a single speaker's name."""
|
||||
# Get this speaker's lines (after 15s)
|
||||
# Get this speaker's lines
|
||||
speaker_lines = [(ts, text) for ts, spk, text in dialogue_lines
|
||||
if spk == speaker and parse_timestamp(ts) > 15]
|
||||
if spk == speaker]
|
||||
|
||||
# Prioritize lines with identifying keywords - Mars mentions first
|
||||
mars_lines = [l for l in speaker_lines if 'mars' in l[1].lower()]
|
||||
@@ -323,8 +299,18 @@ Who is Speaker {speaker}? Reply with a single descriptive name (e.g., "Moon", "E
|
||||
return ask_llm_for_name(prompt, client, model, debug_filename, step)
|
||||
|
||||
|
||||
def process_lines_file(input_path: Path, client: OpenAI, model: str) -> Path:
|
||||
def process_lines_file(input_path: Path, client: OpenAI, model: str, force: bool = False) -> Path:
|
||||
"""Process a single lines file using multi-step approach."""
|
||||
progress = load_progress()
|
||||
filename = input_path.name
|
||||
|
||||
# Check if already processed
|
||||
if not force and filename in progress and progress[filename].get("status") == "completed":
|
||||
output_path = Path(progress[filename]["output_file"])
|
||||
if output_path.exists():
|
||||
print(f"Skipping {filename} (already processed)")
|
||||
return output_path
|
||||
|
||||
print(f"\n{'='*50}")
|
||||
print(f"Processing: {input_path.name}")
|
||||
print(f"{'='*50}")
|
||||
@@ -343,21 +329,21 @@ def process_lines_file(input_path: Path, client: OpenAI, model: str) -> Path:
|
||||
print(" No valid lines found!")
|
||||
return None
|
||||
|
||||
# Classify speakers
|
||||
song_speakers, dialogue_speakers = classify_speakers_by_time(lines)
|
||||
print(f" Dialogue speakers: {', '.join(sorted(dialogue_speakers))}")
|
||||
# Get unique speakers (excluding "Song" - already known)
|
||||
all_speakers = set(speaker for _, speaker, _ in lines)
|
||||
speakers_to_identify = [s for s in all_speakers if s != "Song"]
|
||||
|
||||
# Build mapping starting with song speakers
|
||||
print(f" Speakers to identify: {', '.join(sorted(speakers_to_identify))}")
|
||||
|
||||
# Build mapping
|
||||
final_mapping = {}
|
||||
for speaker in song_speakers:
|
||||
final_mapping[speaker] = "Song"
|
||||
|
||||
if not dialogue_speakers:
|
||||
print(f" All lines are within first 15 seconds (opening song)")
|
||||
if not speakers_to_identify:
|
||||
print(f" No speakers to identify (only Song present)")
|
||||
else:
|
||||
# Separate regular speakers from unknown/merged speakers (like "?")
|
||||
regular_speakers = [s for s in dialogue_speakers if s.isalpha()]
|
||||
unknown_speakers = [s for s in dialogue_speakers if not s.isalpha()]
|
||||
regular_speakers = [s for s in speakers_to_identify if s.isalpha()]
|
||||
unknown_speakers = [s for s in speakers_to_identify if not s.isalpha()]
|
||||
|
||||
# Step 1: Identify Malabar (from regular speakers only)
|
||||
print(f" Step 1: Identifying Malabar...")
|
||||
@@ -414,6 +400,14 @@ def process_lines_file(input_path: Path, client: OpenAI, model: str) -> Path:
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
f.write(output_text)
|
||||
|
||||
# Update progress
|
||||
progress[filename] = {
|
||||
"status": "completed",
|
||||
"output_file": str(output_path),
|
||||
"speaker_mapping": final_mapping
|
||||
}
|
||||
save_progress(progress)
|
||||
|
||||
print(f" Saved to: {output_path}")
|
||||
|
||||
return output_path
|
||||
@@ -422,13 +416,13 @@ def process_lines_file(input_path: Path, client: OpenAI, model: str) -> Path:
|
||||
def apply_speaker_names(lines: List[Tuple[str, str, str]], mapping: Dict[str, str]) -> str:
|
||||
"""Apply speaker names to lines.
|
||||
|
||||
SPECIAL: Lines in first 15 seconds are labeled as "Song" (opening theme).
|
||||
SPECIAL: "Song" speaker is passed through unchanged (already labeled in Step 2).
|
||||
"""
|
||||
result_lines = []
|
||||
|
||||
for timestamp, speaker, text in lines:
|
||||
# Check if this line is in the first 15 seconds
|
||||
if parse_timestamp(timestamp) <= 15:
|
||||
# "Song" speaker is already correctly labeled - pass through unchanged
|
||||
if speaker == "Song":
|
||||
speaker_name = "Song"
|
||||
else:
|
||||
speaker_name = mapping.get(speaker, f"Speaker_{speaker}")
|
||||
@@ -448,6 +442,9 @@ def get_input_files() -> list[Path]:
|
||||
def main():
|
||||
ensure_dirs()
|
||||
|
||||
# Check for force flag
|
||||
force = "--force" in sys.argv or "-f" in sys.argv
|
||||
|
||||
# Get LLM config
|
||||
base_url, model = get_llm_config()
|
||||
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=base_url)
|
||||
@@ -460,6 +457,8 @@ def main():
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Found {len(lines_files)} transcript(s) in {INPUT_DIR}/")
|
||||
if force:
|
||||
print("Force mode: ON (reprocessing all files)")
|
||||
print(f"Debug info will be saved to {DEBUG_DIR}/")
|
||||
print("")
|
||||
|
||||
@@ -469,10 +468,13 @@ def main():
|
||||
|
||||
for input_path in lines_files:
|
||||
try:
|
||||
output_path = process_lines_file(input_path, client, model)
|
||||
output_path = process_lines_file(input_path, client, model, force=force)
|
||||
if output_path:
|
||||
success_count += 1
|
||||
except Exception as e:
|
||||
progress = load_progress()
|
||||
progress[input_path.name] = {"status": "error", "error": str(e)}
|
||||
save_progress(progress)
|
||||
print(f"\n❌ Failed to process {input_path.name}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
Reference in New Issue
Block a user