diff --git a/step1_transcribe.py b/step1_transcribe.py index a097713..ce7be5b 100644 --- a/step1_transcribe.py +++ b/step1_transcribe.py @@ -11,7 +11,6 @@ Usage: """ import os -import re import sys import json from pathlib import Path @@ -30,162 +29,6 @@ def ensure_dirs(): OUTPUT_DIR.mkdir(exist_ok=True) -def split_words_by_sentences(words: list) -> list: - """ - Split a list of words into sentence segments based on punctuation. - - Args: - words: List of word dictionaries with 'text' key - - Returns: - List of word segments, each representing a sentence - """ - if not words: - return [] - - segments = [] - current_segment = [] - - # Pattern for sentence-ending punctuation (including the punctuation itself) - sentence_end_pattern = re.compile(r'[.!?]+["\')\]]*$') - - for word in words: - current_segment.append(word) - text = word.get("text", "") - - # Check if this word ends with sentence-ending punctuation - if sentence_end_pattern.search(text): - # End of sentence - save this segment - segments.append(current_segment) - current_segment = [] - - # Don't forget any remaining words - if current_segment: - segments.append(current_segment) - - return segments - - -def ends_with_sentence_punctuation(text: str) -> bool: - """Check if text ends with sentence-ending punctuation.""" - text = text.strip() - return bool(re.search(r'[.!?]["\'\)\]]*$', text)) - - -def merge_incomplete_sentences(utterances: list) -> list: - """ - Merge consecutive utterances where the first doesn't end with sentence punctuation. - This handles cases where AssemblyAI splits mid-sentence between speakers. - Uses the first speaker's label for merged utterances. - """ - if not utterances: - return utterances - - result = [] - current = utterances[0].copy() - - for i in range(1, len(utterances)): - next_utt = utterances[i] - current_text = current.get("text", "") - - # If current doesn't end with sentence punctuation, merge with next - if not ends_with_sentence_punctuation(current_text): - # Merge words - current["words"] = current.get("words", []) + next_utt.get("words", []) - # Update text - current["text"] = current_text + " " + next_utt.get("text", "") - # Update end time - current["end"] = next_utt.get("end", current["end"]) - # Keep the first speaker's label (don't change to "?") - # current["speaker"] stays the same - else: - # Current is complete, save it and move to next - result.append(current) - current = next_utt.copy() - - # Don't forget the last one - result.append(current) - - return result - - -def split_utterances_by_pauses(utterances: list, pause_threshold_ms: int = 1500) -> list: - """ - Split long utterances based on pauses between words and sentence boundaries. - - Args: - utterances: List of utterance dictionaries from AssemblyAI - pause_threshold_ms: Minimum gap (in milliseconds) to create a new utterance - - Returns: - List of split utterances - """ - # First, merge consecutive utterances that don't end with sentence punctuation - utterances = merge_incomplete_sentences(utterances) - - result = [] - - for utt in utterances: - words = utt.get("words", []) - if not words: - # No word-level data, keep original - result.append(utt) - continue - - speaker = utt.get("speaker", "?") - current_segment_words = [] - segments = [] - - for i, word in enumerate(words): - if not current_segment_words: - # First word in segment - current_segment_words.append(word) - else: - # Check gap from previous word - prev_word = current_segment_words[-1] - gap = word.get("start", 0) - prev_word.get("end", 0) - - if gap >= pause_threshold_ms: - # Gap is large enough - first split by sentences within current segment - sentence_segments = split_words_by_sentences(current_segment_words) - for seg_words in sentence_segments: - segments.append({ - "speaker": speaker, - "words": seg_words, - "start": seg_words[0]["start"], - "end": seg_words[-1]["end"] - }) - current_segment_words = [word] - else: - # Continue current segment - current_segment_words.append(word) - - # Don't forget the last segment - also split by sentences - if current_segment_words: - sentence_segments = split_words_by_sentences(current_segment_words) - for seg_words in sentence_segments: - segments.append({ - "speaker": speaker, - "words": seg_words, - "start": seg_words[0]["start"], - "end": seg_words[-1]["end"] - }) - - # Convert segments to utterance format - for seg in segments: - text = " ".join(w.get("text", "") for w in seg["words"]).strip() - if text: # Only add non-empty segments - result.append({ - "speaker": seg["speaker"], - "text": text, - "start": seg["start"], - "end": seg["end"], - "words": seg["words"] - }) - - return result - - def load_progress() -> dict: """Load progress tracking.""" if PROGRESS_FILE.exists(): @@ -214,12 +57,11 @@ def transcribe_video(video_path: Path) -> dict: print(f" Uploading {video_path.name}...") # Speaker diarization config - # By default, AssemblyAI detects 1-10 speakers - # If you know the expected number, you can set speakers_expected - # Or set speaker_options for a range + # Lower speaker_sensitivity = more aggressive speaker detection (more speakers) speaker_options = aai.SpeakerOptions( min_speakers=2, - max_speakers=10 # Allow up to 10 speakers + max_speakers=10, # Allow up to 10 speakers + speaker_sensitivity=0.2 # Low value = more sensitive to speaker changes ) config = aai.TranscriptionConfig( @@ -237,35 +79,8 @@ def transcribe_video(video_path: Path) -> dict: print(f" Transcription complete!") - # Convert utterances to dictionaries first - raw_utterances = [] - for utt in transcript.utterances: - raw_utterances.append({ - "speaker": utt.speaker, - "text": utt.text.strip(), - "start": utt.start, - "end": utt.end, - "confidence": utt.confidence if hasattr(utt, 'confidence') else None, - "words": [ - { - "text": w.text, - "start": w.start, - "end": w.end, - "speaker": w.speaker if hasattr(w, 'speaker') else None - } - for w in (utt.words if hasattr(utt, 'words') else []) - ] - }) - - # Split long utterances based on pauses - original_count = len(raw_utterances) - split_utterances = split_utterances_by_pauses(raw_utterances, pause_threshold_ms=1500) - new_count = len(split_utterances) - - if new_count > original_count: - print(f" Split {original_count} utterances into {new_count} (based on 1.5s pauses)") - - # Convert transcript to serializable dictionary + # Convert transcript to serializable dictionary - NO POSTPROCESSING + # Raw AssemblyAI output result = { "id": transcript.id, "status": str(transcript.status), @@ -274,7 +89,25 @@ def transcribe_video(video_path: Path) -> dict: "confidence": transcript.confidence, "audio_duration": transcript.audio_duration, "language_code": transcript.json_response.get("language_code", "unknown"), - "utterances": split_utterances + "utterances": [ + { + "speaker": utt.speaker, + "text": utt.text.strip(), + "start": utt.start, + "end": utt.end, + "confidence": utt.confidence if hasattr(utt, 'confidence') else None, + "words": [ + { + "text": w.text, + "start": w.start, + "end": w.end, + "speaker": w.speaker if hasattr(w, 'speaker') else None + } + for w in (utt.words if hasattr(utt, 'words') else []) + ] + } + for utt in transcript.utterances + ] } return result diff --git a/step2_format.py b/step2_format.py index 2a3847f..f6a7da7 100644 --- a/step2_format.py +++ b/step2_format.py @@ -16,7 +16,12 @@ import sys import json import re from pathlib import Path -from typing import List, Dict, Any +from typing import List, Dict, Any, Tuple + +# ============== Configuration ============== + +# Split utterances on pauses longer than this (milliseconds) +PAUSE_THRESHOLD_MS = 1500 # ============== Configuration ============== @@ -51,6 +56,89 @@ def ensure_dirs(): OUTPUT_DIR.mkdir(exist_ok=True) +def split_words_by_sentences(words: list) -> list: + """Split words into sentence segments based on punctuation.""" + if not words: + return [] + + segments = [] + current_segment = [] + sentence_end_pattern = re.compile(r'[.!?]+["\')\]]*$') + + for word in words: + current_segment.append(word) + text = word.get("text", "") + if sentence_end_pattern.search(text): + segments.append(current_segment) + current_segment = [] + + if current_segment: + segments.append(current_segment) + + return segments + + +def split_utterances_by_pauses(utterances: list, pause_threshold_ms: int = 1500) -> list: + """Split long utterances based on pauses between words and sentence boundaries.""" + result = [] + + for utt in utterances: + words = utt.get("words", []) + if not words: + result.append(utt) + continue + + speaker = utt.get("speaker", "?") + current_segment_words = [] + segments = [] + + for i, word in enumerate(words): + if not current_segment_words: + current_segment_words.append(word) + else: + prev_word = current_segment_words[-1] + gap = word.get("start", 0) - prev_word.get("end", 0) + + if gap >= pause_threshold_ms: + # Gap is large enough - split by sentences within current segment + sentence_segments = split_words_by_sentences(current_segment_words) + for seg_words in sentence_segments: + segments.append({ + "speaker": speaker, + "words": seg_words, + "start": seg_words[0]["start"], + "end": seg_words[-1]["end"] + }) + current_segment_words = [word] + else: + current_segment_words.append(word) + + # Process final segment + if current_segment_words: + sentence_segments = split_words_by_sentences(current_segment_words) + for seg_words in sentence_segments: + segments.append({ + "speaker": speaker, + "words": seg_words, + "start": seg_words[0]["start"], + "end": seg_words[-1]["end"] + }) + + # Convert segments to utterance format + for seg in segments: + text = " ".join(w.get("text", "") for w in seg["words"]).strip() + if text: + result.append({ + "speaker": seg["speaker"], + "text": text, + "start": seg["start"], + "end": seg["end"], + "words": seg["words"] + }) + + return result + + def format_timestamp(ms: int) -> str: """Format milliseconds as [mm:ss].""" seconds = ms // 1000 @@ -108,6 +196,73 @@ def merge_utterances(utterances: List[Dict[str, Any]]) -> List[Dict[str, Any]]: return merged +def extract_opening_song_title(utterances: List[Dict[str, Any]]) -> Tuple[str, str, str, List[Dict[str, Any]]]: + """ + Extract title from opening song (lines within first 15 seconds). + Returns (title, song_speaker, joined_song_lyrics, remaining_utterances). + + The title is the text after 'Malabar' in the opening song lyrics. + All opening song lyrics (except title) are joined into one string. + """ + OPENING_SONG_THRESHOLD_MS = 15000 # 15 seconds + + # Separate opening song utterances (within first 15s) from the rest + opening_song = [] + remaining = [] + + for utt in utterances: + if utt.get("start", 0) < OPENING_SONG_THRESHOLD_MS: + opening_song.append(utt) + else: + remaining.append(utt) + + if not opening_song: + return "", "", "", utterances + + # Find the utterance containing "Malabar" + malabar_idx = -1 + title = "" + song_speaker = opening_song[0].get("speaker", "A") if opening_song else "A" + title_utterance_idx = -1 # The utterance that contains the title (to exclude from song) + + for i, utt in enumerate(opening_song): + text = utt.get("text", "") + if "Malabar" in text or "malabar" in text.lower(): + malabar_idx = i + song_speaker = utt.get("speaker", song_speaker) + # Extract title: text after "Malabar" (and any punctuation/space) + match = re.search(r'Malabar[\s,]*(.+)', text, re.IGNORECASE) + if match: + title = match.group(1).strip() + # Remove trailing punctuation from title + title = re.sub(r'[.!?]+$', '', title).strip() + title_utterance_idx = i + # Remove title part from this utterance for song lyrics + utt["text"] = re.sub(r'Malabar[\s,]*.+$', 'Malabar', text, flags=re.IGNORECASE).strip() + break + + # If title not in same utterance as Malabar, check next utterance(s) + if not title and malabar_idx >= 0: + for j in range(malabar_idx + 1, len(opening_song)): + next_text = opening_song[j].get("text", "").strip() + if next_text: + title = re.sub(r'[.!?]+$', '', next_text).strip() + title_utterance_idx = j + break + + # Join all opening song lyrics except the title utterance + song_lines = [] + for i, utt in enumerate(opening_song): + if i != title_utterance_idx: + text = utt.get("text", "").strip() + if text: + song_lines.append(text) + + joined_song = " ".join(song_lines) + + return title, song_speaker, joined_song, remaining + + def format_lines(transcript_data: Dict[str, Any]) -> str: """ Format transcript utterances into lines. @@ -118,12 +273,32 @@ def format_lines(transcript_data: Dict[str, Any]) -> str: if not utterances: return "" + # Split long utterances based on pauses and sentence boundaries + utterances = split_utterances_by_pauses(utterances, PAUSE_THRESHOLD_MS) + + # Extract title from opening song (first 15 seconds) and get joined song lyrics + title, song_speaker, joined_song, utterances = extract_opening_song_title(utterances) + # Merge non-word utterances merged = merge_utterances(utterances) # Format lines lines = [] + + # Add title as first line if found (use "Song" as speaker) + if title: + lines.append(f"[00:00](Song) {title}") + + # Add joined opening song as second line if exists (use "Song" as speaker) + if joined_song: + lines.append(f"[00:01](Song) {joined_song}") + + # Format remaining lines (skip those within first 15s as they're in the joined song) for utt in merged: + # Skip utterances within opening song window (they're already included in joined_song) + if utt.get("start", 0) < 15000: + continue + text = utt.get("text", "").strip() # Skip standalone non-words unless they're at the end @@ -155,10 +330,10 @@ def process_transcript(input_path: Path) -> Path: with open(input_path, 'r', encoding='utf-8') as f: transcript_data = json.load(f) - utterance_count = len(transcript_data.get("utterances", [])) - print(f" Loaded {utterance_count} utterances") + raw_count = len(transcript_data.get("utterances", [])) + print(f" Loaded {raw_count} raw utterances") - # Format lines + # Format lines (includes splitting by pauses) formatted_text = format_lines(transcript_data) # Save output diff --git a/step3_infer_speakers.py b/step3_infer_speakers.py index 412b60d..64f31ab 100644 --- a/step3_infer_speakers.py +++ b/step3_infer_speakers.py @@ -32,6 +32,7 @@ from openai import OpenAI INPUT_DIR = Path("_lines") OUTPUT_DIR = Path("_speakers") DEBUG_DIR = Path("_speakers_debug") +PROGRESS_FILE = Path(".step3_progress.json") # Examples of good speaker names (for reference, not a restricted list) NAME_EXAMPLES = ["Malabar", "Sun", "Jupiter", "Kangaroo", "Mole"] @@ -59,6 +60,20 @@ def ensure_dirs(): DEBUG_DIR.mkdir(exist_ok=True) +def load_progress() -> dict: + """Load progress tracking.""" + if PROGRESS_FILE.exists(): + with open(PROGRESS_FILE, 'r', encoding='utf-8') as f: + return json.load(f) + return {} + + +def save_progress(progress: dict): + """Save progress tracking.""" + with open(PROGRESS_FILE, 'w', encoding='utf-8') as f: + json.dump(progress, f, indent=2) + + def get_llm_config() -> Tuple[str, str]: """Get LLM configuration from environment.""" api_key = os.getenv("OPENAI_API_KEY") @@ -83,7 +98,9 @@ def get_llm_config() -> Tuple[str, str]: def parse_lines(lines_text: str) -> List[Tuple[str, str, str]]: """Parse formatted lines. Returns list of (timestamp, speaker_label, text).""" - pattern = r'^(\[\d{2}:\d{2}\])\(Speaker ([A-Z?])\) (.+)$' + # Pattern to match both (Speaker X) and (Song) formats + # Speaker "Song" is reserved for the opening song + pattern = r'^(\[\d{2}:\d{2}\])\((Speaker [A-Z?]|Song)\) (.+)$' result = [] for line in lines_text.strip().split('\n'): @@ -94,60 +111,19 @@ def parse_lines(lines_text: str) -> List[Tuple[str, str, str]]: match = re.match(pattern, line) if match: timestamp = match.group(1) - speaker = match.group(2) + speaker_raw = match.group(2) text = match.group(3) + # Normalize: "Speaker X" -> "X", "Song" -> "Song" + if speaker_raw == "Song": + speaker = "Song" + else: + # Extract letter from "Speaker X" + speaker = speaker_raw.replace("Speaker ", "") result.append((timestamp, speaker, text)) return result -def parse_timestamp(ts: str) -> int: - """Parse [mm:ss] timestamp to total seconds.""" - match = re.match(r'\[(\d{2}):(\d{2})\]', ts) - if match: - minutes = int(match.group(1)) - seconds = int(match.group(2)) - return minutes * 60 + seconds - return 0 - - -def classify_speakers_by_time(lines: List[Tuple[str, str, str]]) -> Tuple[set, set]: - """Classify speakers based on when they appear.""" - all_speakers = set(speaker for _, speaker, _ in lines) - - song_speakers = set() - dialogue_speakers = set() - - for speaker in all_speakers: - has_lines_after_15 = any( - parse_timestamp(ts) > 15 and spk == speaker - for ts, spk, _ in lines - ) - if has_lines_after_15: - dialogue_speakers.add(speaker) - else: - has_lines_in_first_15 = any( - parse_timestamp(ts) <= 15 and spk == speaker - for ts, spk, _ in lines - ) - if has_lines_in_first_15: - song_speakers.add(speaker) - - return song_speakers, dialogue_speakers - - -def format_dialogue_with_names(lines: List[Tuple[str, str, str]], speaker_names: Dict[str, str]) -> str: - """Format dialogue lines with known speaker names.""" - result_lines = [] - for timestamp, speaker, text in lines: - # Skip lines in first 15 seconds (opening song) - if parse_timestamp(timestamp) <= 15: - continue - name = speaker_names.get(speaker, f"Speaker_{speaker}") - result_lines.append(f'{timestamp}({name}) {text}') - return '\n'.join(result_lines) - - def save_debug(filename: str, request: str, response: str, step: int): """Save debug info to _speakers_debug folder.""" debug_file = DEBUG_DIR / f"{filename}_step{step}.txt" @@ -244,9 +220,9 @@ def ask_llm_for_name(prompt: str, client: OpenAI, model: str, debug_filename: st def identify_malabar(dialogue_lines: List[Tuple[str, str, str]], client: OpenAI, model: str, debug_filename: str) -> Optional[str]: """Identify which speaker is Malabar.""" - # Only consider single-letter speakers (exclude "?" and other special markers) + # Only consider single-letter speakers (exclude "?", "Song", and other special markers) speakers = sorted(set(speaker for _, speaker, _ in dialogue_lines - if parse_timestamp(_) > 15 and len(speaker) == 1 and speaker.isalpha())) + if len(speaker) == 1 and speaker.isalpha())) if not speakers: return None @@ -255,7 +231,7 @@ def identify_malabar(dialogue_lines: List[Tuple[str, str, str]], samples = [] for speaker in speakers: lines = [(ts, text) for ts, spk, text in dialogue_lines - if spk == speaker and parse_timestamp(ts) > 15][:3] + if spk == speaker][:3] for ts, text in lines: samples.append(f'{speaker}: "{text}"') @@ -282,9 +258,9 @@ def identify_speaker(speaker: str, known_names: Dict[str, str], client: OpenAI, model: str, debug_filename: str, step: int) -> str: """Identify a single speaker's name.""" - # Get this speaker's lines (after 15s) + # Get this speaker's lines speaker_lines = [(ts, text) for ts, spk, text in dialogue_lines - if spk == speaker and parse_timestamp(ts) > 15] + if spk == speaker] # Prioritize lines with identifying keywords - Mars mentions first mars_lines = [l for l in speaker_lines if 'mars' in l[1].lower()] @@ -323,8 +299,18 @@ Who is Speaker {speaker}? Reply with a single descriptive name (e.g., "Moon", "E return ask_llm_for_name(prompt, client, model, debug_filename, step) -def process_lines_file(input_path: Path, client: OpenAI, model: str) -> Path: +def process_lines_file(input_path: Path, client: OpenAI, model: str, force: bool = False) -> Path: """Process a single lines file using multi-step approach.""" + progress = load_progress() + filename = input_path.name + + # Check if already processed + if not force and filename in progress and progress[filename].get("status") == "completed": + output_path = Path(progress[filename]["output_file"]) + if output_path.exists(): + print(f"Skipping {filename} (already processed)") + return output_path + print(f"\n{'='*50}") print(f"Processing: {input_path.name}") print(f"{'='*50}") @@ -343,21 +329,21 @@ def process_lines_file(input_path: Path, client: OpenAI, model: str) -> Path: print(" No valid lines found!") return None - # Classify speakers - song_speakers, dialogue_speakers = classify_speakers_by_time(lines) - print(f" Dialogue speakers: {', '.join(sorted(dialogue_speakers))}") + # Get unique speakers (excluding "Song" - already known) + all_speakers = set(speaker for _, speaker, _ in lines) + speakers_to_identify = [s for s in all_speakers if s != "Song"] - # Build mapping starting with song speakers + print(f" Speakers to identify: {', '.join(sorted(speakers_to_identify))}") + + # Build mapping final_mapping = {} - for speaker in song_speakers: - final_mapping[speaker] = "Song" - if not dialogue_speakers: - print(f" All lines are within first 15 seconds (opening song)") + if not speakers_to_identify: + print(f" No speakers to identify (only Song present)") else: # Separate regular speakers from unknown/merged speakers (like "?") - regular_speakers = [s for s in dialogue_speakers if s.isalpha()] - unknown_speakers = [s for s in dialogue_speakers if not s.isalpha()] + regular_speakers = [s for s in speakers_to_identify if s.isalpha()] + unknown_speakers = [s for s in speakers_to_identify if not s.isalpha()] # Step 1: Identify Malabar (from regular speakers only) print(f" Step 1: Identifying Malabar...") @@ -414,6 +400,14 @@ def process_lines_file(input_path: Path, client: OpenAI, model: str) -> Path: with open(output_path, 'w', encoding='utf-8') as f: f.write(output_text) + # Update progress + progress[filename] = { + "status": "completed", + "output_file": str(output_path), + "speaker_mapping": final_mapping + } + save_progress(progress) + print(f" Saved to: {output_path}") return output_path @@ -422,13 +416,13 @@ def process_lines_file(input_path: Path, client: OpenAI, model: str) -> Path: def apply_speaker_names(lines: List[Tuple[str, str, str]], mapping: Dict[str, str]) -> str: """Apply speaker names to lines. - SPECIAL: Lines in first 15 seconds are labeled as "Song" (opening theme). + SPECIAL: "Song" speaker is passed through unchanged (already labeled in Step 2). """ result_lines = [] for timestamp, speaker, text in lines: - # Check if this line is in the first 15 seconds - if parse_timestamp(timestamp) <= 15: + # "Song" speaker is already correctly labeled - pass through unchanged + if speaker == "Song": speaker_name = "Song" else: speaker_name = mapping.get(speaker, f"Speaker_{speaker}") @@ -448,6 +442,9 @@ def get_input_files() -> list[Path]: def main(): ensure_dirs() + # Check for force flag + force = "--force" in sys.argv or "-f" in sys.argv + # Get LLM config base_url, model = get_llm_config() client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=base_url) @@ -460,6 +457,8 @@ def main(): sys.exit(1) print(f"Found {len(lines_files)} transcript(s) in {INPUT_DIR}/") + if force: + print("Force mode: ON (reprocessing all files)") print(f"Debug info will be saved to {DEBUG_DIR}/") print("") @@ -469,10 +468,13 @@ def main(): for input_path in lines_files: try: - output_path = process_lines_file(input_path, client, model) + output_path = process_lines_file(input_path, client, model, force=force) if output_path: success_count += 1 except Exception as e: + progress = load_progress() + progress[input_path.name] = {"status": "error", "error": str(e)} + save_progress(progress) print(f"\n❌ Failed to process {input_path.name}: {e}") import traceback traceback.print_exc()