# whisper.cpp/models/convert-silero-vad-to-ggml.py
#
# 197 lines
# 7.1 KiB
# Python

import os
import struct
import argparse
import torch
import numpy as np
from silero_vad import load_silero_vad, __version__ as silero_version
def _write_i32(fout, value):
    """Write ``value`` to ``fout`` using struct format ``"i"`` (native order/size)."""
    fout.write(struct.pack("i", value))


def _parse_version_triplet(version):
    """Leniently parse ``version`` into a ``(major, minor, patch)`` tuple of ints.

    Only the leading decimal digits of the first three dot-separated
    components are used; missing or non-numeric components become 0.
    So "5.1.2" -> (5, 1, 2), "5.1" -> (5, 1, 0), "5.1.2.post1" -> (5, 1, 2).
    """
    numbers = []
    for component in version.split(".")[:3]:
        digits = ""
        for ch in component:
            if ch not in "0123456789":
                break
            digits += ch
        numbers.append(int(digits) if digits else 0)
    numbers.extend([0] * (3 - len(numbers)))
    return tuple(numbers)


def convert_silero_vad(output_path, print_tensors=True):
    """Convert the Silero-VAD model returned by ``load_silero_vad()`` to a GGML file.

    The file is written to ``<base>-v<silero_version>-ggml<ext>``, where
    ``base``/``ext`` come from splitting ``output_path``.

    File layout (every integer is written with struct format ``"i"``):
      * magic ``0x67676d6c`` (ASCII "ggml")
      * model type string: length, then UTF-8 bytes ("silero-16k")
      * version: major, minor, patch
      * window size, context size, number of encoder layers
      * per encoder layer: in channels, out channels, kernel size
      * LSTM input size, LSTM hidden size
      * final conv in channels, out channels
      * per tensor: n_dims, name length, ftype (0 = float32, 1 = float16),
        the dimensions in reversed order, the name bytes, the raw data

    The architecture values above are hard-coded constants, not read from the
    loaded model.

    Args:
        output_path: Base path for the output file; the version and a
            ``-ggml`` suffix are inserted before the extension.
        print_tensors: When true (the default), print the full state-dict
            listing and per-tensor debug details. When false, only the
            progress summary and warnings are printed.

    Returns:
        None. The result is the file written to disk.
    """
    model = load_silero_vad()
    state_dict = model.state_dict()

    # Drop every entry whose key mentions the 8k model and make sure all
    # remaining keys carry the "_model." prefix used for lookups below.
    cleaned_dict = {}
    for key, value in state_dict.items():
        if "_8k" in key:
            continue
        clean_key = key if key.startswith("_model.") else "_model." + key
        cleaned_dict[clean_key] = value

    base, ext = os.path.splitext(output_path)
    output_file = f"{base}-v{silero_version}-ggml{ext}"
    print(f"Saving GGML Silero-VAD model to {output_file}")

    if print_tensors:
        print("\nTensor info for debugging:")
        for key, tensor in cleaned_dict.items():
            print(f" - {key}: {tensor.shape} ({tensor.dtype})")
        print()

    with open(output_file, "wb") as fout:
        # Magic number
        _write_i32(fout, 0x67676d6c)

        # Model type, length-prefixed. The prefix is the encoded byte count
        # so it always matches what is actually written.
        model_type_bytes = "silero-16k".encode('utf-8')
        _write_i32(fout, len(model_type_bytes))
        fout.write(model_type_bytes)

        # Version triplet; parsed leniently so two-part or suffixed version
        # strings do not abort the conversion.
        major, minor, patch = _parse_version_triplet(silero_version)
        print(f"Version: {major}.{minor}.{patch}")
        _write_i32(fout, major)
        _write_i32(fout, minor)
        _write_i32(fout, patch)

        # Model architecture parameters
        window_size = 512
        _write_i32(fout, window_size)
        context_size = 64
        _write_i32(fout, context_size)
        n_encoder_layers = 4
        _write_i32(fout, n_encoder_layers)

        # Encoder dimensions: (in, out, kernel) per layer
        input_channels = 129
        encoder_in_channels = [input_channels, 128, 64, 64]
        encoder_out_channels = [128, 64, 64, 128]
        kernel_size = 3
        for in_ch, out_ch in zip(encoder_in_channels, encoder_out_channels):
            _write_i32(fout, in_ch)
            _write_i32(fout, out_ch)
            _write_i32(fout, kernel_size)

        # LSTM dimensions
        lstm_input_size = 128
        lstm_hidden_size = 128
        _write_i32(fout, lstm_input_size)
        _write_i32(fout, lstm_hidden_size)

        # Final conv dimensions
        final_conv_in = 128
        final_conv_out = 1
        _write_i32(fout, final_conv_in)
        _write_i32(fout, final_conv_out)

        # Collect the tensors to write, in file order.
        tensor_keys = []

        # Encoder weights: a layer is included only if both its weight and
        # bias are present.
        for i in range(n_encoder_layers):
            weight_key = f"_model.encoder.{i}.reparam_conv.weight"
            bias_key = f"_model.encoder.{i}.reparam_conv.bias"
            if weight_key in cleaned_dict and bias_key in cleaned_dict:
                tensor_keys.append(weight_key)
                tensor_keys.append(bias_key)

        # LSTM weights
        lstm_keys = [
            "_model.decoder.rnn.weight_ih",
            "_model.decoder.rnn.weight_hh",
            "_model.decoder.rnn.bias_ih",
            "_model.decoder.rnn.bias_hh",
        ]
        tensor_keys.extend([k for k in lstm_keys if k in cleaned_dict])

        # Final conv weights
        final_keys = [
            "_model.decoder.decoder.2.weight",
            "_model.decoder.decoder.2.bias",
        ]
        tensor_keys.extend([k for k in final_keys if k in cleaned_dict])

        # STFT basis goes last. It is appended unconditionally, so a missing
        # buffer shows up as MISSING in the summary and is skipped with a
        # warning in the write loop.
        stft_tensor = "_model.stft.forward_basis_buffer"
        tensor_keys.append(stft_tensor)

        print(f"Writing {len(tensor_keys)} tensors:")
        for key in tensor_keys:
            if key in cleaned_dict:
                print(f" - {key}: {cleaned_dict[key].shape}")
            else:
                print(f" - {key}: MISSING")

        # Write each tensor: header, dimensions, name, raw data.
        for key in tensor_keys:
            if key not in cleaned_dict:
                print(f"Warning: Missing tensor {key}, skipping")
                continue

            tensor = cleaned_dict[key]

            if key == stft_tensor:
                # The STFT basis is written without squeezing, always as a
                # 3-D tensor with its dimensions listed in reversed order.
                data = tensor.detach().cpu().numpy()
                if print_tensors:
                    print(f"STFT tensor original shape: {data.shape}")
                n_dims = 3
                tensor_shape = [data.shape[2], data.shape[1], data.shape[0]]
                is_conv_weight = True
            else:
                # All other tensors have size-1 axes squeezed away, keep at
                # most 4 dimensions, and list them in reversed order.
                data = tensor.detach().cpu().squeeze().numpy()
                n_dims = min(data.ndim, 4)
                tensor_shape = list(data.shape)[:n_dims]
                tensor_shape.reverse()
                is_conv_weight = "weight" in key and (
                    "encoder" in key or "_model.decoder.decoder.2" in key
                )

            if is_conv_weight:
                data = data.astype(np.float16)
                ftype = 1  # float16
            else:
                # ftype 0 declares float32, so make sure the bytes written
                # really are float32 (no copy when they already are).
                data = data.astype(np.float32, copy=False)
                ftype = 0  # float32

            if print_tensors:
                print(f"\nWriting tensor: {key}")
                print(f" Original shape: {tensor.shape}")
                print(f" Processed shape: {data.shape}")
                print(f" GGML dimensions: {n_dims}")
                print(f" GGML shape: {tensor_shape}")
                print(f" Type: {'float16' if ftype == 1 else 'float32'}")

            name_bytes = key.encode('utf-8')

            # Tensor header
            _write_i32(fout, n_dims)
            _write_i32(fout, len(name_bytes))
            _write_i32(fout, ftype)

            # Tensor dimensions (tensor_shape always has exactly n_dims entries)
            for i, size in enumerate(tensor_shape):
                _write_i32(fout, size)
                if print_tensors:
                    print(f" Writing dimension {i}: {size}")

            # Tensor name, then raw data
            fout.write(name_bytes)
            data.tofile(fout)
            if print_tensors:
                print(f" Wrote {data.nbytes} bytes")

    print(f"\nDone! Model has been converted to GGML format: {output_file}")
    print(f"File size: {os.path.getsize(output_file)} bytes")
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the conversion.
    parser = argparse.ArgumentParser(description="Convert Silero-VAD PyTorch model to GGML format")
    parser.add_argument("--output", type=str, required=True, help="Path to output GGML model file")
    # --print-tensors is store_true with default=True, so on its own it can
    # never produce False. It is kept for backward compatibility and paired
    # with --no-print-tensors, which writes False to the same destination.
    parser.add_argument("--print-tensors", dest="print_tensors", action="store_true",
                        help="Print tensor values", default=True)
    parser.add_argument("--no-print-tensors", dest="print_tensors", action="store_false",
                        help="Do not print tensor information")
    args = parser.parse_args()
    convert_silero_vad(args.output, args.print_tensors)