You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
432 lines
14 KiB
432 lines
14 KiB
import os
import uuid
from dataclasses import dataclass, field
from urllib.parse import quote

import mido
from fastapi import APIRouter, UploadFile, File, HTTPException
from fastapi.responses import Response
from pydantic import BaseModel

from ..core import baketempo as baketempo_core
from ..core import monofy as monofy_core
from ..core import reduncheck as reduncheck_core
from ..core import type0 as type0_core
from ..core import velfix as velfix_core
from ..core.analyze import analyze_midi
from ..core.file_handling import load_midi_from_bytes, midi_to_bytes
from ..core.midi_utils import has_musical_messages, get_instrument_name
from ..core.track_detail import get_track_detail
|
|
|
|
# Every endpoint declared on this router is served under /api/session.
router = APIRouter(prefix="/api/session")

# In-memory session store, keyed by the session id string handed out at
# upload time. Nothing in this module evicts entries, and the contents do
# not survive a process restart.
sessions: dict[str, "Session"] = {}
|
|
|
|
|
|
@dataclass
class Session:
    """Editing state kept in memory for one uploaded MIDI file."""

    # Current serialized MIDI file; replaced after every successful edit.
    midi_bytes: bytes
    # File name as supplied at upload time (used for analysis and download).
    original_name: str
    # Earlier values of midi_bytes, oldest first; undo pops the last one.
    undo_stack: list[bytes] = field(default_factory=list)
    # Human-readable label for each applied edit, parallel to undo_stack.
    history: list[str] = field(default_factory=list)
|
|
|
|
|
|
class ApplyRequest(BaseModel):
    # Request body for the "apply tool" endpoint.

    # Tool identifier; the endpoint recognises "baketempo", "monofy",
    # "reduncheck", "velfix" and "type0".
    tool: str
    # Optional channel numbers; in this module they only appear in the
    # history label.
    channels: list[int] | None = None
    # Velocity bounds, required (and range-checked to 0-127) for "velfix".
    vel_min: int | None = None
    vel_max: int | None = None
    # Optional 0-based musical track indices limiting which tracks the
    # track-aware tools touch; None means no restriction is passed on.
    tracks: list[int] | None = None
|
|
|
|
|
|
class TrackEditRequest(BaseModel):
    # Request body for the per-track edit endpoint.

    # New channel as a 1-based number (the endpoint accepts 1-16);
    # None leaves channels untouched.
    channel: int | None = None
    # New program number (the endpoint accepts 0-127);
    # None leaves the instrument untouched.
    program: int | None = None
|
|
|
|
|
|
class MergeRequest(BaseModel):
    # Request body for the merge endpoint.

    # 0-based musical track indices to combine; the endpoint requires
    # at least two.
    tracks: list[int]
|
|
|
|
|
|
def _musical_to_raw_indices(midi, musical_indices: set[int]) -> set[int]:
    """Map 0-based musical track indices onto raw positions in ``midi.tracks``.

    Only tracks for which ``has_musical_messages`` is true are counted when
    numbering musical tracks. Requested indices that do not correspond to
    any musical track are silently dropped from the result.
    """
    # Raw positions of the musical tracks, in file order; the position of an
    # entry inside this list is that track's musical index.
    raw_positions = [
        raw for raw, candidate in enumerate(midi.tracks)
        if has_musical_messages(candidate)
    ]
    return {
        raw for ordinal, raw in enumerate(raw_positions)
        if ordinal in musical_indices
    }
|
|
|
|
|
|
def _find_musical_track(midi, track_index: int):
    """Look up a musical track by its 0-based musical index.

    Tracks for which ``has_musical_messages`` is false are skipped when
    counting. Returns ``(raw_index, track)`` for a match, otherwise
    ``(None, None)``.
    """
    # Lazy on purpose: scanning stops as soon as the wanted track is reached.
    musical_tracks = (
        (raw, candidate)
        for raw, candidate in enumerate(midi.tracks)
        if has_musical_messages(candidate)
    )
    for ordinal, (raw, candidate) in enumerate(musical_tracks):
        if ordinal == track_index:
            return raw, candidate
    return None, None
|
|
|
|
|
|
def _analyze_session(session: Session) -> dict:
    """Analyze the session's current MIDI bytes.

    The analyzer is given the parsed file together with the uploaded file
    name stripped of its extension.
    """
    stem, _extension = os.path.splitext(session.original_name)
    return analyze_midi(load_midi_from_bytes(session.midi_bytes), stem)
|
|
|
|
|
|
@router.post("/upload")
async def upload(file: UploadFile = File(...)):
    """Create a new editing session from an uploaded MIDI file.

    Returns the new session id, the analysis of the uploaded file and an
    empty history list.

    Raises:
        HTTPException: 400 if the upload has no ``.mid``/``.midi`` file name
            or its content cannot be parsed as MIDI.
    """
    # The file name is optional on an upload; treating a missing name as
    # empty yields the regular 400 instead of an AttributeError (500).
    original_name = file.filename or ""
    if not original_name.lower().endswith(('.mid', '.midi')):
        raise HTTPException(400, "File must be a .mid or .midi file")

    content = await file.read()
    try:
        midi = load_midi_from_bytes(content)
    except Exception as e:
        raise HTTPException(400, f"Invalid MIDI file: {e}") from e

    filename = os.path.splitext(original_name)[0]
    analysis = analyze_midi(midi, filename)

    # Register the session only once analysis succeeded, so a failing upload
    # does not leave behind a session whose id the client never received.
    session_id = str(uuid.uuid4())
    sessions[session_id] = Session(midi_bytes=content, original_name=original_name)

    return {
        "session_id": session_id,
        "analysis": analysis,
        "history": []
    }
|
|
|
|
|
|
def _selected_raw_tracks(midi, request: ApplyRequest):
    """Return the raw track indices selected by ``request.tracks``, or None when unset."""
    if request.tracks is None:
        return None
    return _musical_to_raw_indices(midi, set(request.tracks))


@router.post("/{session_id}/apply")
async def apply_tool(session_id: str, request: ApplyRequest):
    """Run one processing tool on the session's MIDI data.

    On success the previous bytes are pushed onto the undo stack, the
    processed file becomes the current one and a label is appended to the
    history. On any failure the session is left exactly as it was.

    Returns the fresh analysis and the full history.

    Raises:
        HTTPException: 404 for an unknown session; 400 for an unknown tool
            or invalid ``velfix`` bounds; 500 if the stored MIDI cannot be
            loaded or the tool itself fails.
    """
    session = sessions.get(session_id)
    if not session:
        raise HTTPException(404, "Session not found")

    try:
        midi = load_midi_from_bytes(session.midi_bytes)
    except Exception as e:
        raise HTTPException(500, f"Failed to load MIDI: {e}") from e

    # Build history label
    label = _tool_label(request)

    try:
        if request.tool == "baketempo":
            result = baketempo_core.process(midi)
        elif request.tool == "monofy":
            result = monofy_core.process(midi, _selected_raw_tracks(midi, request))
        elif request.tool == "reduncheck":
            result = reduncheck_core.process(midi, _selected_raw_tracks(midi, request))
        elif request.tool == "velfix":
            if request.vel_min is None or request.vel_max is None:
                raise HTTPException(400, "vel_min and vel_max are required for velfix")
            if not (0 <= request.vel_min <= 127) or not (0 <= request.vel_max <= 127):
                raise HTTPException(400, "velocities must be 0-127")
            if request.vel_min > request.vel_max:
                raise HTTPException(400, "vel_min must be <= vel_max")
            result = velfix_core.process(
                midi, request.vel_min, request.vel_max,
                _selected_raw_tracks(midi, request),
            )
        elif request.tool == "type0":
            result = type0_core.process(midi)
        else:
            raise HTTPException(400, f"Unknown tool: {request.tool}")
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(500, f"Processing error: {e}") from e

    # Serialize before touching the session: if this raises, the undo stack
    # and the history stay in step with each other.
    new_bytes = midi_to_bytes(result)

    # Commit snapshot, new data and label together.
    session.undo_stack.append(session.midi_bytes)
    session.midi_bytes = new_bytes
    session.history.append(label)

    analysis = _analyze_session(session)
    return {
        "analysis": analysis,
        "history": session.history
    }
|
|
|
|
|
|
@router.post("/{session_id}/undo")
async def undo(session_id: str):
    # Revert the most recent edit: restore the last snapshot from the undo
    # stack, drop the matching history label, and return the re-analysed
    # file together with the remaining history.
    current = sessions.get(session_id)
    if not current:
        raise HTTPException(404, "Session not found")
    if not current.undo_stack:
        raise HTTPException(400, "Nothing to undo")

    previous_bytes = current.undo_stack.pop()
    current.midi_bytes = previous_bytes
    current.history.pop()

    return {
        "analysis": _analyze_session(current),
        "history": current.history
    }
|
|
|
|
|
|
def _content_disposition(filename: str) -> str:
    """Build an ``attachment`` Content-Disposition value that is safe for any file name.

    The plain ``filename`` parameter carries an ASCII-only fallback with
    quotes, backslashes and non-printable characters replaced; the
    ``filename*`` parameter carries the exact name, percent-encoded as UTF-8.
    """
    ascii_name = filename.encode("ascii", "replace").decode("ascii")
    ascii_name = "".join(
        "_" if ch in '"\\' or not ch.isprintable() else ch
        for ch in ascii_name
    )
    encoded = quote(filename, safe="")
    return f'attachment; filename="{ascii_name}"; ' + f"filename*=UTF-8''{encoded}"


@router.get("/{session_id}/download")
async def download(session_id: str):
    """Return the session's current MIDI file as an attachment.

    The suggested file name is the uploaded name, or ``<base>_edited.mid``
    once at least one edit is recorded in the history.

    Raises:
        HTTPException: 404 for an unknown session.
    """
    session = sessions.get(session_id)
    if not session:
        raise HTTPException(404, "Session not found")

    base = os.path.splitext(session.original_name)[0]
    filename = f"{base}_edited.mid" if session.history else session.original_name

    # The name comes from the uploader, so it is never interpolated raw into
    # the header: a double quote would break the header syntax, and a
    # character outside latin-1 could not be encoded as a header value.
    return Response(
        content=session.midi_bytes,
        media_type="audio/midi",
        headers={"Content-Disposition": _content_disposition(filename)}
    )
|
|
|
|
|
|
@router.get("/{session_id}/track/{track_index}")
async def track_detail(session_id: str, track_index: int):
    # Return the detail record that get_track_detail produces for one track
    # of the session's current file; 404 when the session is unknown or the
    # helper reports no such track (None).
    owner = sessions.get(session_id)
    if not owner:
        raise HTTPException(404, "Session not found")

    parsed = load_midi_from_bytes(owner.midi_bytes)
    info = get_track_detail(parsed, track_index)
    if info is None:
        raise HTTPException(404, "Track not found")
    return info
|
|
|
|
|
|
@router.post("/{session_id}/track/{track_index}/edit")
async def edit_track(session_id: str, track_index: int, request: TrackEditRequest):
    """Change the channel and/or the instrument of one musical track.

    ``track_index`` is a 0-based musical track index. A channel change
    rewrites every channel message of the track; a program change updates
    the first ``program_change`` of the track, or inserts one after the
    leading meta messages when the track has none.

    The session is modified only after the edited file was serialized
    successfully, so undo stack and history always stay in step. A request
    that sets neither field changes nothing and records no history entry.

    Returns the fresh analysis and the full history.

    Raises:
        HTTPException: 404 for an unknown session or track; 400 for an
            out-of-range channel or program; 500 if the stored MIDI cannot
            be loaded.
    """
    session = sessions.get(session_id)
    if not session:
        raise HTTPException(404, "Session not found")

    if request.channel is not None and not (1 <= request.channel <= 16):
        raise HTTPException(400, "Channel must be 1-16")
    if request.program is not None and not (0 <= request.program <= 127):
        raise HTTPException(400, "Program must be 0-127")

    try:
        midi = load_midi_from_bytes(session.midi_bytes)
    except Exception as e:
        raise HTTPException(500, f"Failed to load MIDI: {e}") from e

    _raw_idx, target_track = _find_musical_track(midi, track_index)
    if target_track is None:
        raise HTTPException(404, "Track not found")

    if request.channel is None and request.program is None:
        # Nothing to change: avoid recording an empty "<name>: " history
        # entry and a pointless undo snapshot.
        return {
            "analysis": _analyze_session(session),
            "history": session.history
        }

    # Track name for the history label, falling back to a 1-based number.
    track_name = next(
        (msg.name for msg in target_track if msg.type == 'track_name'),
        f"Track {track_index + 1}",
    )

    label_parts = []

    if request.channel is not None:
        new_channel = request.channel - 1  # stored channels are 0-based
        channel_msgs = [msg for msg in target_track if hasattr(msg, 'channel')]
        old_channels = sorted({msg.channel + 1 for msg in channel_msgs})
        old_ch_str = ",".join(str(c) for c in old_channels) if old_channels else "?"
        for msg in channel_msgs:
            msg.channel = new_channel
        label_parts.append(f"CH {old_ch_str} \u2192 {request.channel}")

    if request.program is not None:
        instrument_name = get_instrument_name(request.program)
        first_pc = next(
            (msg for msg in target_track if msg.type == 'program_change'),
            None,
        )
        if first_pc is not None:
            first_pc.program = request.program
        else:
            # Use the channel of the first channel message; otherwise the
            # requested channel, or 0 when none was requested.
            fallback_ch = request.channel - 1 if request.channel is not None else 0
            ch = next(
                (msg.channel for msg in target_track if hasattr(msg, 'channel')),
                fallback_ch,
            )
            pc_msg = mido.Message('program_change', program=request.program, channel=ch, time=0)
            # Insert right after the run of leading meta messages; with
            # time=0 the delta times of the following messages are unchanged.
            insert_idx = 0
            for j, msg in enumerate(target_track):
                if not msg.is_meta:
                    break
                insert_idx = j + 1
            target_track.insert(insert_idx, pc_msg)
        label_parts.append(f"Instrument \u2192 {instrument_name}")

    label = f"{track_name}: {', '.join(label_parts)}"

    # Serialize first; commit snapshot, data and label only on success.
    new_bytes = midi_to_bytes(midi)
    session.undo_stack.append(session.midi_bytes)
    session.midi_bytes = new_bytes
    session.history.append(label)

    analysis = _analyze_session(session)
    return {
        "analysis": analysis,
        "history": session.history
    }
|
|
|
|
|
|
@router.post("/{session_id}/track/{track_index}/delete")
async def delete_track(session_id: str, track_index: int):
    """Remove one musical track from the session's MIDI file.

    ``track_index`` is a 0-based musical track index. The session is
    modified only after the resulting file was serialized successfully, so
    undo stack and history always stay in step.

    Returns the fresh analysis and the full history.

    Raises:
        HTTPException: 404 for an unknown session or track; 500 if the
            stored MIDI cannot be loaded.
    """
    session = sessions.get(session_id)
    if not session:
        raise HTTPException(404, "Session not found")

    try:
        midi = load_midi_from_bytes(session.midi_bytes)
    except Exception as e:
        raise HTTPException(500, f"Failed to load MIDI: {e}") from e

    raw_idx, target_track = _find_musical_track(midi, track_index)
    if target_track is None:
        raise HTTPException(404, "Track not found")

    # Track name for the history label, falling back to a 1-based number.
    track_name = next(
        (msg.name for msg in target_track if msg.type == 'track_name'),
        f"Track {track_index + 1}",
    )

    midi.tracks.pop(raw_idx)
    label = f"Delete {track_name}"

    # Serialize first; commit snapshot, data and label only on success.
    new_bytes = midi_to_bytes(midi)
    session.undo_stack.append(session.midi_bytes)
    session.midi_bytes = new_bytes
    session.history.append(label)

    analysis = _analyze_session(session)
    return {
        "analysis": analysis,
        "history": session.history
    }
|
|
|
|
|
|
@router.post("/{session_id}/merge")
async def merge_tracks(session_id: str, request: MergeRequest):
    """Merge several musical tracks into a single track.

    ``request.tracks`` holds 0-based musical track indices; duplicates are
    ignored. The merged track takes the first track name found among the
    sources (or "Merged"), contains all their other messages ordered by
    absolute time, and replaces the sources at the position of the first
    one.

    The session is modified only after the resulting file was serialized
    successfully, so undo stack and history always stay in step.

    Returns the fresh analysis and the full history.

    Raises:
        HTTPException: 404 for an unknown session or track; 400 when fewer
            than two distinct tracks are given; 500 if the stored MIDI
            cannot be loaded.
    """
    session = sessions.get(session_id)
    if not session:
        raise HTTPException(404, "Session not found")

    # De-duplicate: a repeated index would copy that track's events twice
    # and pop the same raw position twice, deleting an unrelated track.
    musical_indices = sorted(set(request.tracks))
    if len(musical_indices) < 2:
        raise HTTPException(400, "At least 2 tracks required for merge")

    try:
        midi = load_midi_from_bytes(session.midi_bytes)
    except Exception as e:
        raise HTTPException(500, f"Failed to load MIDI: {e}") from e

    # Resolve every requested track before changing anything.
    sources = []
    for musical_idx in musical_indices:
        raw_idx, track = _find_musical_track(midi, musical_idx)
        if track is None:
            raise HTTPException(404, f"Track {musical_idx} not found")
        sources.append((raw_idx, track))

    first_track_name = next(
        (
            msg.name
            for _raw, track in sources
            for msg in track
            if msg.type == 'track_name'
        ),
        "Merged",
    )

    # Collect all events with absolute timing. Source track names are
    # dropped, and so are the per-track end_of_track markers: copying them
    # would put end markers in the middle of the merged track.
    all_events = []
    end_time = 0
    for _raw, track in sources:
        absolute_time = 0
        for msg in track:
            absolute_time += msg.time
            if msg.type in ('track_name', 'end_of_track'):
                continue
            all_events.append((absolute_time, msg.copy(time=0)))
        end_time = max(end_time, absolute_time)

    # Stable sort by absolute time, then convert back to delta times.
    all_events.sort(key=lambda item: item[0])
    merged_track = mido.MidiTrack()
    merged_track.append(mido.MetaMessage('track_name', name=first_track_name, time=0))
    prev_time = 0
    for abs_time, msg in all_events:
        msg.time = abs_time - prev_time
        merged_track.append(msg)
        prev_time = abs_time
    # One end marker, placed where the longest source track ended.
    merged_track.append(mido.MetaMessage('end_of_track', time=end_time - prev_time))

    # Remove old tracks in reverse order, then insert merged
    raw_indices = [raw for raw, _track in sources]
    for raw_idx in sorted(raw_indices, reverse=True):
        midi.tracks.pop(raw_idx)
    midi.tracks.insert(min(raw_indices), merged_track)

    track_nums = ", ".join(str(t + 1) for t in musical_indices)
    label = f"Merge Tracks {track_nums}"

    # Serialize first; commit snapshot, data and label only on success.
    new_bytes = midi_to_bytes(midi)
    session.undo_stack.append(session.midi_bytes)
    session.midi_bytes = new_bytes
    session.history.append(label)

    analysis = _analyze_session(session)
    return {
        "analysis": analysis,
        "history": session.history
    }
|
|
|
|
|
|
def _tool_label(request: ApplyRequest) -> str:
|
|
names = {
|
|
"baketempo": "Bake Tempo",
|
|
"monofy": "Monofy",
|
|
"reduncheck": "Remove Redundancy",
|
|
"velfix": "Velocity Fix",
|
|
"type0": "Convert to Type 0"
|
|
}
|
|
label = names.get(request.tool, request.tool)
|
|
parts = []
|
|
if request.channels:
|
|
parts.append(f"CH {','.join(str(c) for c in sorted(request.channels))}")
|
|
if request.tracks is not None:
|
|
parts.append(f"Tracks {','.join(str(t + 1) for t in sorted(request.tracks))}")
|
|
if request.vel_min is not None and request.vel_max is not None:
|
|
parts.append(f"vel={request.vel_min}-{request.vel_max}")
|
|
if parts:
|
|
label += f" ({'; '.join(parts)})"
|
|
return label
|