char5742 / flutter_silero_vad

This is an unofficial plugin for calling the Silero VAD ONNX model in Flutter.

Error: Shape Mismatch #17

Open parkjihwanjay opened 1 month ago

parkjihwanjay commented 1 month ago

I encountered this error:

execution_frame.cc:857 VerifyOutputSizes] Expected shape from model of {1,-1,128} does not match actual shape of {1,1,128,6} for output If_0_then_branch__Inline_0__/Unsqueeze_1_output_0

I assume it's related to the model's expected input shape. I followed the same steps as the example code in this repository. Is there any way to fix this error?
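
For context, here is the frame-size arithmetic behind my settings (a small sketch; it assumes 16 kHz, 16-bit mono PCM and treats FRAME_SIZE as milliseconds, the same way my code below does; the frameMath helper is only for illustration). I also wonder whether the v5 model expects a different, fixed window length than the 3200 samples per frame I end up feeding it, but that is just a guess on my part.

// Illustration-only helper showing the per-frame sizes implied by my settings.
void frameMath() {
  const sampleRate = 16000; // SAMPLE_RATE in the code below
  const frameSizeMs = 200; // FRAME_SIZE in the code below
  const bytesPerSample = 2; // 16-bit PCM

  final samplesPerFrame = frameSizeMs * sampleRate ~/ 1000; // 3200 samples
  final bytesPerFrame = samplesPerFrame * bytesPerSample; // 6400 bytes

  print('samples per frame: $samplesPerFrame');
  print('bytes per frame: $bytesPerFrame');
}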

Thanks for your great work! I hope this project continues to be developed.

Regards,

parkjihwanjay commented 1 month ago
import 'dart:async';
import 'dart:io';
import 'dart:typed_data';

import 'package:audio_session/audio_session.dart';
import 'package:flutter/material.dart';
import 'package:flutter_silero_vad/flutter_silero_vad.dart';
import 'package:flutter_sound/flutter_sound.dart';
import 'package:flutter/services.dart';
import 'package:permission_handler/permission_handler.dart';

import 'package:path_provider/path_provider.dart';

import 'package:web_socket_channel/web_socket_channel.dart';
import 'package:web_socket_channel/status.dart' as status;

import 'package:audio_streamer/audio_streamer.dart';

void main() {
  runApp(MyApp());
}

class MyApp extends StatelessWidget {
  @override
  Widget build(BuildContext context) {
    return MaterialApp(
      title: 'Voice Recorder',
      theme: ThemeData(
        primarySwatch: Colors.blue,
      ),
      home: RecorderScreen(),
    );
  }
}

Int16List _transformBuffer(List<int> buffer) {
  final bytes = Uint8List.fromList(buffer);
  return Int16List.view(bytes.buffer);
}

class RecorderScreen extends StatefulWidget {
  @override
  _RecorderScreenState createState() => _RecorderScreenState();
}

class _RecorderScreenState extends State<RecorderScreen> {
  final FlutterSoundRecorder _recorder = FlutterSoundRecorder();
  final FlutterSoundPlayer _player = FlutterSoundPlayer();

  final recorder = AudioStreamer.instance;

  StreamSubscription<List<int>>? recordingDataSubscription;

  final vad = FlutterSileroVad();

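  // Audio/VAD configuration: 16 kHz, 16-bit mono PCM; FRAME_SIZE is treated
  // as milliseconds by the buffering code below.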
  final SAMPLE_RATE = 16000;
  final FRAME_SIZE = 200;
  final int bitsPerSample = 16;
  final int numChannels = 1;

  Future<String> get modelPath async =>
      '${(await getApplicationSupportDirectory()).path}/silero_vad.v5.onnx';

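  // Copies the bundled ONNX model from the asset bundle to local storage so
  // it can be handed to the VAD as a file path.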
  Future<void> onnxModelToLocal() async {
    final data = await rootBundle.load('assets/silero_vad.v5.onnx');
    final bytes =
        data.buffer.asUint8List(data.offsetInBytes, data.lengthInBytes);
    File(await modelPath).writeAsBytesSync(bytes);
  }

  late WebSocketChannel _channel;

  bool _isRecording = false;

  late StreamController<Uint8List> _streamController;

  final frameBuffer = <int>[];

  @override
  void initState() {
    super.initState();
    _initializeRecorder();
  }

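  // Writes the model to disk and initializes the VAD plugin with the
  // settings above.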
  Future<void> _initializeVad() async {
    await onnxModelToLocal();
    await vad.initialize(
      modelPath: await modelPath,
      sampleRate: SAMPLE_RATE,
      frameSize: FRAME_SIZE,
      threshold: 0.7,
      minSilenceDurationMs: 100,
      speechPadMs: 0,
    );
  }

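  // Connects to the WebSocket server and plays back any audio data it sends.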
  Future<void> _connectWebSocket() async {
    // WebSocket server on the same local network
    final wsUrl = Uri.parse('ws://192.168.0.174:8000/ws');
    _channel = WebSocketChannel.connect(wsUrl);
    await _channel.ready;

    _channel.stream.listen((data) {
      _playReceivedAudio(data);
    });
  }

  Future<void> _playReceivedAudio(Uint8List audioData) async {
    await _player.startPlayer(
      fromDataBuffer: audioData,
      codec: Codec.pcm16WAV,
    );
  }

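  // Requests microphone permission, configures the audio session for voice
  // chat, and opens the recorder and player.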
  Future<void> _initializeRecorder() async {
    await Permission.microphone.request();
    final session = await AudioSession.instance;
    await session.configure(AudioSessionConfiguration(
      avAudioSessionCategory: AVAudioSessionCategory.playAndRecord,
      avAudioSessionCategoryOptions:
          AVAudioSessionCategoryOptions.allowBluetooth |
              AVAudioSessionCategoryOptions.defaultToSpeaker,
      avAudioSessionMode: AVAudioSessionMode.voiceChat,
      avAudioSessionRouteSharingPolicy:
          AVAudioSessionRouteSharingPolicy.defaultPolicy,
      avAudioSessionSetActiveOptions: AVAudioSessionSetActiveOptions.none,
      androidAudioAttributes: const AndroidAudioAttributes(
        contentType: AndroidAudioContentType.speech,
        flags: AndroidAudioFlags.none,
        usage: AndroidAudioUsage.voiceCommunication,
      ),
      androidAudioFocusGainType: AndroidAudioFocusGainType.gain,
      androidWillPauseWhenDucked: true,
    ));
    await _recorder.openRecorder();
    await _player.openPlayer();
  }

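  // Initializes the VAD, subscribes to the microphone stream, slices it into
  // FRAME_SIZE-millisecond frames, and connects the WebSocket.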
  Future<void> _startRecording() async {
    // _streamController = StreamController<Uint8List>();
    // _recorder.setSubscriptionDuration(const Duration(milliseconds: 10));

    await _initializeVad();

    recordingDataSubscription = recorder.audioStream.listen((buffer) async {
      final data = _transformBuffer(buffer);
      if (data.isEmpty) return;

      frameBuffer.addAll(data);

      while (frameBuffer.length >= FRAME_SIZE * 2 * SAMPLE_RATE ~/ 1000) {
        print('Running VAD...');
        final b =
            frameBuffer.take(FRAME_SIZE * 2 * SAMPLE_RATE ~/ 1000).toList();
        frameBuffer.removeRange(0, FRAME_SIZE * 2 * SAMPLE_RATE ~/ 1000);
        await _handleProcessedAudio(b);
      }
    });

    await recorder.startRecording();

    // _streamController.stream.listen((buffer) async {
    //   final byteBuffer = buffer.buffer;

    //   final data = Int16List.view(byteBuffer);
    //   if (data.isEmpty) return;

    //   frameBuffer.addAll(buffer);

    //   while(frameBuffer.length >= FRAME_SIZE * 2 * SAMPLE_RATE ~/ 1000) {
    //     print('Running VAD...');
    //     final b = frameBuffer.take(FRAME_SIZE * 2 * SAMPLE_RATE ~/ 1000).toList();
    //     frameBuffer.removeRange(0, FRAME_SIZE * 2 * SAMPLE_RATE ~/ 1000);
    //     await _handleProcessedAudio(b);
    //   }
    // });
    await _connectWebSocket();
    // await _recorder.startRecorder(
    //   codec:  Codec.pcm16,
    //   toStream: _streamController.sink,
    //   numChannels: numChannels,
    //   sampleRate: SAMPLE_RATE,
    // );

    print("Started recording");
    setState(() {
      _isRecording = true;
    });
  }

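  // Scales one frame of 16-bit PCM to floats in [-1, 1], runs VAD prediction
  // on it, and forwards the raw frame over the WebSocket.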
  Future<void> _handleProcessedAudio(List<int> buffer) async {
    final transformedBuffer = _transformBuffer(buffer);
    final transformedBufferFloat =
        transformedBuffer.map((e) => e / 32768).toList();

    final isActivated =
        await vad.predict(Float32List.fromList(transformedBufferFloat));

    print("VAD result: $isActivated");

    _channel.sink.add(buffer);
  }

  Future<void> _stopRecording() async {
    await _recorder.stopRecorder();
    await _player.stopPlayer();
    await _streamController.close();
    // await _channel.sink.close();
    _channel.sink.add('STOP_RECORDING');
    setState(() {
      _isRecording = false;
    });
  }

  @override
  void dispose() {
    _recorder.closeRecorder();
    _player.closePlayer();
    super.dispose();
  }

  @override
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(
        title: Text('Voice Recorder'),
      ),
      body: Center(
        child: Column(
          mainAxisAlignment: MainAxisAlignment.center,
          children: [
            if (!_isRecording)
              _buildStartRecordingButton()
            else
              _buildStopRecordingButton(),
          ],
        ),
      ),
    );
  }

  Widget _buildStopRecordingButton() {
    return ElevatedButton(
      onPressed: _stopRecording,
      child: Text('Stop Recording'),
    );
  }

  Widget _buildStartRecordingButton() {
    return ElevatedButton(
      onPressed: _startRecording,
      child: Text('Start Recording'),
    );
  }
}

This is my code.
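
If it helps, below is a stripped-down sketch of how I drive the plugin, with the recorder and WebSocket removed and one synthetic 200 ms frame of silence in place of microphone audio. It only uses the plugin calls and parameters already shown above; the runVadOnce helper name is just for illustration.

import 'dart:io';
import 'dart:typed_data';

import 'package:flutter/services.dart';
import 'package:flutter_silero_vad/flutter_silero_vad.dart';
import 'package:path_provider/path_provider.dart';

/// Illustration-only helper: copies the model, initializes the VAD, and runs
/// a single prediction on a silent 200 ms frame.
Future<void> runVadOnce() async {
  // Copy the bundled v5 model to a readable path, as in the app code above.
  final dir = await getApplicationSupportDirectory();
  final modelPath = '${dir.path}/silero_vad.v5.onnx';
  final data = await rootBundle.load('assets/silero_vad.v5.onnx');
  await File(modelPath).writeAsBytes(
    data.buffer.asUint8List(data.offsetInBytes, data.lengthInBytes),
  );

  final vad = FlutterSileroVad();
  await vad.initialize(
    modelPath: modelPath,
    sampleRate: 16000,
    frameSize: 200,
    threshold: 0.7,
    minSilenceDurationMs: 100,
    speechPadMs: 0,
  );

  // One 200 ms frame of silence: 3200 float samples in [-1, 1].
  final frame = Float32List(200 * 16000 ~/ 1000);
  final isActivated = await vad.predict(frame);
  print('VAD result: $isActivated');
}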