/*
 * Scraper boilerplate retained from the page this file was copied from,
 * kept as a comment so the file parses as JavaScript:
 * 2022-07-28 | 来源:未知 | 责任编辑:Selina | 人气:
 * 核心提示:大家好,欢迎来到谷雨课堂。
 * 本课内容:大家好,欢迎来到谷雨课堂
 */
import { tf } from './Network/Network.js';
import { LogStftFeature } from './Feature/DataParser.js';
import { PinYin } from './Label/pinyinbase.js';
import { baseURI } from "../../utils/other_utils.js";
// ASR model: turns raw audio into a pinyin sequence.
// Pipeline: log-STFT feature extraction -> tensorflow.js graph model ->
// per-frame softmax/argmax -> pinyin index lookup.
class ASRModel {
    constructor() {
        // Intentionally empty: init() is async and must be awaited by the caller.
        // this.init();
    };
    /**
     * Load everything the model needs (feature config, pinyin table, network).
     * @param {String} ModelDir directory containing feature.json and model.json;
     *   resolved against baseURI.
     */
    init = async (ModelDir = './ASR/Model/Network/tensorflowjs/tfjsModel/tfjs_mobilev3small_thchs30/') => {
        ModelDir = new URL(ModelDir, baseURI).href;
        await this.initFeature(ModelDir);
        await this.initPinYin();
        await this.initNetwork(ModelDir);
        this.ModelDir = ModelDir;
    };
    // Build the pinyin index-to-label table used to decode argmax indices.
    initPinYin = async () => {
        this.pinyin = new PinYin(undefined);
        await this.pinyin.init();
    };
    // Fetch feature.json and construct the log-STFT feature extractor from it.
    initFeature = async (ModelDir) => {
        console.log(`准备加载feature...`);
        const response = await fetch(ModelDir + 'feature.json');
        this.featureConfig = await response.json();
        this.feature = new LogStftFeature(this.featureConfig.sampleRate, this.featureConfig.fft_s, this.featureConfig.hop_s);
    };
    // Load the tf.js graph model and warm it up with a zero tensor so the first
    // real prediction is fast; also derives the time-axis downsampling factor.
    initNetwork = async (ModelDir) => {
        console.log(`准备加载model...`);
        console.log(`当前tensorflowJS的Backend:${tf.getBackend()}`);
        this.tfjs_model = await tf.loadGraphModel(ModelDir + 'model.json');
        console.log(`model加载完成,进行model预热...`);
        const preloadTimeN = 1024;
        // Frequency-bin count: fft_n / 2 + 1, with fft_n = fft_s * sampleRate.
        const frequencyN = Math.round(this.feature.fft_s * this.feature.sampleRate / 2) + 1;
        const warmupInput = tf.zeros([1, preloadTimeN, frequencyN]);
        const preloadRes = this.tfjs_model.predict(warmupInput);
        // How many input frames collapse into one output step along the time axis.
        const viewK = preloadTimeN / preloadRes.shape[1];
        console.log(`model时间维视野缩放比为${viewK},因此单个拼音输出视野时长为${this.feature.hop_s * viewK}s`);
        console.log(`model预热完成`);
        // Fix: tf.js tensors are not garbage collected — dispose the warm-up
        // tensors explicitly; previously they leaked (GPU) memory.
        tf.dispose([warmupInput, preloadRes]);
        this.viewK = viewK;
        this.eachOutPutTime = viewK * this.feature.hop_s;
    };
    /**
     * Run recognition on raw audio.
     * @param {*} audioData audio accepted by LogStftFeature.logstft_audioData
     * @returns {Promise<Object>} see predictStftData
     */
    predictAudioData = async (audioData) => {
        const logstftData = this.feature.logstft_audioData(audioData);
        return await this.predictStftData(logstftData);
    };
    /**
     * Run recognition on precomputed log-STFT features.
     * @param {*} stftData object with a .stft matrix (typedArrayView, rowsN,
     *   columnsN) plus timing fields that are copied through to the result.
     * @returns {Promise<Object>} { pinyinArray, audioEndTime, audioStartTime, timeLength }
     */
    predictStftData = async (stftData) => {
        const onebatch_stft_tfTensor = tf.tensor(stftData.stft.typedArrayView, [1, stftData.stft.rowsN, stftData.stft.columnsN]);
        const predict_res = await this.tfjs_model.execute(onebatch_stft_tfTensor);
        const squeezed_res = predict_res.squeeze(0);
        const softmax_res = squeezed_res.softmax();
        const argmax_res = softmax_res.argMax(-1);
        const argmax_res_array = argmax_res.arraySync();
        // Fix: dispose every tensor created above — without this, each
        // prediction leaked five tensors (tf.js requires manual disposal).
        tf.dispose([onebatch_stft_tfTensor, predict_res, squeezed_res, softmax_res, argmax_res]);
        const pinyinArray = argmax_res_array.map(max_arg => this.pinyin.num2py(max_arg));
        const predictResult = {
            'pinyinArray': pinyinArray,
            'audioEndTime': stftData.audioEndTime,
            'audioStartTime': stftData.audioStartTime,
            'timeLength': stftData.timeLength,
        };
        return predictResult;
    };
};
export { ASRModel };
import { Drawer } from '../Drawer/Drawer.js';
import { AudioData, StftData, AudioDataCyclicContainer, StftDataCyclicContainer } from './AudioContainer.js';
import { CyclicFloat32Array, CyclicImageData, Float32Matrix } from '../utils/CyclicContainer.js';
// Height in pixels reserved at the bottom of a drawer canvas for the time axis.
export const TIME_AREA_H = 20;
// Colour-ramp stops (ivory -> lavender blush -> deep sky blue -> purple)
// interpolated by convert_to_rgb to map intensity in [0, 1] onto RGB.
const RGBcolors = [[255, 255, 240], [255, 240, 245], [0, 191, 255], [160, 32, 240]];
/**
 * Convert a scalar in [0, 1] to an opaque RGBA pixel using the module-level
 * colour ramp (RGBcolors).
 *
 * @param {Number} num float in [0, 1]
 * @returns {Uint8ClampedArray} 4-byte pixel [r, g, b, 255]
 */
export function num2color(num) {
    const [r, g, b] = convert_to_rgb(num, RGBcolors);
    const rgba = new Uint8ClampedArray(4);
    rgba.set([r, g, b, 255]);
    return rgba;
};
/**
 * Map a scalar to an RGB colour by piecewise-linear interpolation across the
 * given colour stops.
 *
 * @param {Number} val value to colourise
 * @param {Array[Uint8ClampedArray(3)]} RGBcolors colour stops, e.g.
 *   [[255, 255, 240], [255, 240, 245], [0, 191, 255], [160, 32, 240]]
 * @param {Number} min_val value mapped to the first stop
 * @param {Number} max_val value mapped to the last stop
 * @returns {Array} [r, g, b]; when val lands exactly on a stop the stop array
 *   itself is returned (shared reference), otherwise a fresh array.
 */
function convert_to_rgb(val, RGBcolors, min_val = 0, max_val = 1) {
    // Normalise into [0, 1] relative to [min_val, max_val].
    val = (val - min_val) / (max_val - min_val);
    const i_f = val * (RGBcolors.length - 1);
    const i = Math.floor(i_f); // index of the lower colour stop
    const f = i_f % 1;         // fractional position between adjacent stops
    if (f === 0) {
        // Exactly on a stop: no interpolation needed.
        return RGBcolors[i];
    } else {
        const [[r1, g1, b1], [r2, g2, b2]] = [RGBcolors[i], RGBcolors[i + 1]];
        return [
            Math.round(r1 + f * (r2 - r1)),
            Math.round(g1 + f * (g2 - g1)),
            Math.round(b1 + f * (b2 - b1)),
        ];
    };
};
// Quarter-sine easing: maps x in [0, 1] onto [0, 1] with a fast start and a
// flat approach to 1, via sin(x * pi / 2).
function sin_one(x) {
    const scaled = x * Math.PI / 2;
    return Math.sin(scaled);
};
/**
 * Circular-arc easing used to boost low pixel intensities.
 * |x| is clamped to r = 0.7, then mapped onto a circle arc:
 * y = sqrt(t * (2r - t)) + 1 - r, which rises steeply near 0 and flattens
 * to exactly 1 at t = r. Symmetric in x; x === 0 maps to 0 by a deliberate
 * special case (the formula itself would give 1 - r at t = 0).
 *
 * @param {Number} x input intensity (any sign, typically in [-1, 1])
 * @returns {Number} eased value in [0, 1]
 */
function circle_one(x) {
    if (x === 0) return 0; // exact zero stays zero (discontinuity is intended)
    const r = 0.7;
    let t = Math.abs(x);   // symmetric in x
    if (t > r) t = r;      // clamp to the arc radius
    return Math.sqrt(t * (2 * r - t)) + 1 - r;
};
// Scrolling waveform renderer. Each canvas column summarises a fixed group of
// audio samples as an amplitude histogram, eased by circle_one and coloured by
// num2color. Samples left over from an update (not enough to fill a whole
// column) are buffered in a cyclic container and consumed by the next update.
class WaveDrawer extends Drawer {
constructor(id = 'audioWave',
sampleRate = 8000,
numberOfChannels = 1,
total_duration = 10,
show_time = true,
) {
// 64 samples per pixel column at the 8 kHz reference rate, scaled linearly
// with the actual sample rate (so canvas width depends only on duration).
const sample_n_per_pixel = 64 * sampleRate / 8000;
const width = Math.floor(total_duration * sampleRate / sample_n_per_pixel);
const height = Math.round(width / 10);
super(id, width, height);
this.sampleRate = sampleRate;
this.numberOfChannels = numberOfChannels;
this.total_duration = total_duration;
this.show_time = show_time;
// "nf" = possibly non-integer (float) samples-per-pixel.
this.sample_nf_per_pixel = sample_n_per_pixel;
// Vertical pixels available for the waveform (time-axis strip excluded).
this.wave_area_length = show_time ? this.canvas.height - TIME_AREA_H : this.canvas.height;
this.leftedAudioData = null;
// Buffers the tail samples that did not fill a complete pixel column.
this.leftedAudioDataCyclicContainer = new AudioDataCyclicContainer(sampleRate, numberOfChannels, Math.ceil(this.sample_nf_per_pixel) / sampleRate);
// Ring buffer of rendered columns spanning the whole canvas width.
this.cyclicImageData = new CyclicImageData(this.wave_area_length, this.canvas.width);
};
// Reject input that is not an AudioData or whose sample rate differs from ours.
_checkAudioData = (audioData) => {
if (!(audioData instanceof AudioData)) throw new Error("传入的 audioData 类型不为 AudioData ");
else if (audioData.sampleRate !== this.sampleRate) throw new Error(`传入的 audioData.sampleRate(${audioData.sampleRate}) 与 WaveDrawer.sampleRate(${this.sampleRate}) 不相等`);
};
// Convert incoming audio (plus buffered leftover) into an ImageData of
// complete pixel columns; returns undefined when no full column can be built.
audioData2imageData = (audioData) => {
const leftedSampleLength = this.leftedAudioDataCyclicContainer.sampleLength;
// Number of complete pixel columns this call can produce.
const wave_image_length = Math.floor((audioData.sampleLength + leftedSampleLength) / this.sample_nf_per_pixel);
// Histogram: one row per pixel column, one bin per vertical pixel.
const wave_imgMatrix_count = new Float32Matrix(wave_image_length, this.wave_area_length);
// Each channel gets an equal vertical slice of the wave area.
const per_wave_len = this.wave_area_length / audioData.channels.length;
let row_i = 0;
let each_pixel_sampleGroup_begin = 0;
let each_imgMatrix_countRow_begin_i = 0;
if (row_i < wave_image_length) {
// First column: consume the buffered leftover samples first...
const leftedAudioData = this.leftedAudioDataCyclicContainer.getdata();
for (let k = 0; k < leftedAudioData.sampleLength; k++) {
for (let chN = 0; chN < leftedAudioData.channels.length; chN++) {
const cur_audio_sample = leftedAudioData.channels[chN][k];
// Map a sample in [-1, 1] to a bin inside this channel's vertical slice.
const cur_colI = Math.round((cur_audio_sample * 0.5 + 0.5 + chN) * per_wave_len);
// Row offset is 0 for the first row, so the bare bin index is correct.
wave_imgMatrix_count.typedArrayView[cur_colI] += 1;
};
};
// ...then top the first column up from the head of the new audio.
// NOTE(review): nf_left may be fractional (sample_nf_per_pixel is a float),
// so this loop consumes ceil(nf_left) samples — confirm that is intended.
const nf_left = this.sample_nf_per_pixel - leftedSampleLength;
for (let k = 0; k < nf_left; k++) {
for (let chN = 0; chN < audioData.channels.length; chN++) {
const cur_audio_sample = audioData.channels[chN][k];
const cur_colI = Math.round((cur_audio_sample * 0.5 + 0.5 + chN) * per_wave_len);
wave_imgMatrix_count.typedArrayView[cur_colI] += 1;
};
};
row_i += 1;
each_pixel_sampleGroup_begin = this.sample_nf_per_pixel - leftedSampleLength;
each_imgMatrix_countRow_begin_i += this.wave_area_length;
// Remaining columns: each consumes ~sample_nf_per_pixel samples, tracked
// with float offsets rounded per column to avoid cumulative drift.
while (row_i < wave_image_length) {
const start_k = Math.round(each_pixel_sampleGroup_begin);
const end_k = start_k + this.sample_nf_per_pixel;
for (let k = start_k; k < end_k; k++) {
for (let chN = 0; chN < audioData.channels.length; chN++) {
const cur_audio_sample = audioData.channels[chN][k];
const cur_colI = Math.round((cur_audio_sample * 0.5 + 0.5 + chN) * per_wave_len);
wave_imgMatrix_count.typedArrayView[each_imgMatrix_countRow_begin_i + cur_colI] += 1;
};
};
each_pixel_sampleGroup_begin += this.sample_nf_per_pixel;
each_imgMatrix_countRow_begin_i += this.wave_area_length;
row_i += 1;
};
};
// Store the unconsumed tail for the next call (always, even when no full
// column was produced and the loop above was skipped entirely).
this.leftedAudioDataCyclicContainer.cleardata();
this.leftedAudioDataCyclicContainer.updatedata(
new AudioData(
audioData.sampleRate,
audioData.channels.map(ch => ch.slice(each_pixel_sampleGroup_begin)),
audioData.audioEndTime,
),
);
if (!wave_image_length) return;
// Normalise counts to [0, 1], ease with circle_one, colour with num2color.
// The ImageData is built transposed (width = wave area, height = columns);
// presumably cyclicImageData.toImageDataT() transposes it back — TODO confirm.
const imageData = new ImageData(this.wave_area_length, wave_image_length);
for (let i = 0; i < wave_imgMatrix_count.typedArrayView.length; i += 1) {
const p = i * 4;
const cur_pixel = circle_one(wave_imgMatrix_count.typedArrayView[i] / this.sample_nf_per_pixel);
const color = num2color(cur_pixel);
imageData.data[p + 0] = color[0]; // R value
imageData.data[p + 1] = color[1]; // G value
imageData.data[p + 2] = color[2]; // B value
imageData.data[p + 3] = color[3]; // A value
};
return imageData;
};
/**
 * Render new audio: append its pixel columns to the cyclic image and schedule
 * a redraw via Drawer.setData.
 * @param {AudioData} audioData object of the form:
 * {
 *   sampleRate: Number, audio sample rate in Hz
 *   channels: Array[Float32Array], one entry per channel; every channel has
 *     the same length and each element is one sample
 *   timeStamp: Date.now(), timestamp of the end of the audio
 * }
 */
updateAudioData = (audioData) => {
this._checkAudioData(audioData);
const imageData = this.audioData2imageData(audioData);
if (!imageData) return;
this.cyclicImageData.update(imageData);
this.setData(
{
cyclicImageData: this.cyclicImageData,
audioEndTime: audioData.audioEndTime,
}
);
};
// Blit the accumulated image right-aligned onto the canvas and, when enabled,
// draw time-axis ticks every 0.5 s labelled with the absolute end time offset.
draw = async ({ cyclicImageData, audioEndTime }) => {
const imageData = cyclicImageData.toImageDataT();
this.canvas_ctx.putImageData(imageData,
this.canvas.width - imageData.width, 0,
);
this.audioEndTime = audioEndTime;
if (this.show_time) {
const end_time = audioEndTime;
this.canvas_ctx.beginPath();
const dt = 0.5;
// Horizontal pixels per tick interval.
const time_dx = Math.round(dt * this.canvas.width / this.total_duration);
const s_y = this.canvas.height - 20, e_y = this.canvas.height - 10;
for (let i = 1; time_dx * i <= imageData.width; i += 1) {
this.canvas_ctx.moveTo(this.canvas.width - time_dx * i, s_y);
this.canvas_ctx.lineTo(this.canvas.width - time_dx * i, e_y);
this.canvas_ctx.fillText((end_time - dt * i).toFixed(3).toString(), this.canvas.width - time_dx * (i + 0.5), this.canvas.height);
};
this.canvas_ctx.stroke();
this.canvas_ctx.closePath();
};
};
};
;
// Variant of WaveDrawer with a freely configurable canvas size; the
// samples-per-pixel ratio is derived from total_duration and canvas width.
// Differences from WaveDrawer: the leftover tail is cached as a plain
// AudioData between calls (the cyclic container is only used when an update
// produces no complete column), and draw() clears the canvas before blitting.
class WaveDrawerFlexible extends Drawer {
constructor(id = 'audioWave',
width = null,
height = 125,
sampleRate = 8000,
numberOfChannels = 1,
total_duration = 10,
show_time = true,
) {
// Default width reproduces WaveDrawer's 64-samples-per-pixel layout.
if (!width) width = Math.floor(total_duration * sampleRate / 64);
super(id, width, height);
this.sampleRate = sampleRate;
this.numberOfChannels = numberOfChannels;
this.total_duration = total_duration;
this.show_time = show_time;
// "nf" = possibly non-integer (float) samples-per-pixel.
this.sample_nf_per_pixel = this.total_duration * this.sampleRate / this.canvas.width;
// Vertical pixels available for the waveform (time-axis strip excluded).
this.wave_area_length = show_time ? this.canvas.height - TIME_AREA_H : this.canvas.height;
this.leftedAudioData = null;
// Fallback buffer for tail samples when no complete column is produced.
this.leftedAudioDataCyclicContainer = new AudioDataCyclicContainer(sampleRate, numberOfChannels, Math.ceil(this.sample_nf_per_pixel) / sampleRate);
// Ring buffer of rendered columns spanning the whole canvas width.
this.cyclicImageData = new CyclicImageData(this.wave_area_length, this.canvas.width);
};
// Reject input that is not an AudioData or whose sample rate differs from ours.
_checkAudioData = (audioData) => {
if (!(audioData instanceof AudioData)) throw new Error("传入的 audioData 类型不为 AudioData ");
else if (audioData.sampleRate !== this.sampleRate) throw new Error(`传入的 audioData.sampleRate(${audioData.sampleRate}) 与 WaveDrawer.sampleRate(${this.sampleRate}) 不相等`);
};
// Convert incoming audio (plus any cached leftover) into an ImageData of
// complete pixel columns; returns undefined when no full column can be built.
audioData2imageData = (audioData) => {
// Leftover may live either in the cached AudioData or in the container.
const leftedSampleLength = this.leftedAudioData ? this.leftedAudioData.sampleLength : this.leftedAudioDataCyclicContainer.sampleLength;
// Number of complete pixel columns this call can produce.
const wave_image_length = Math.floor((audioData.sampleLength + leftedSampleLength) / this.sample_nf_per_pixel);
// Histogram: one row per pixel column, one bin per vertical pixel.
const wave_imgMatrix_count = new Float32Matrix(wave_image_length, this.wave_area_length);
// Each channel gets an equal vertical slice of the wave area.
const per_wave_len = this.wave_area_length / audioData.channels.length;
let row_i = 0;
let each_pixel_sampleGroup_begin = 0;
let each_imgMatrix_countRow_begin_i = 0;
if (row_i < wave_image_length) {
// Materialise the leftover from the container if it is not cached yet.
if (!this.leftedAudioData) {
this.leftedAudioData = this.leftedAudioDataCyclicContainer.getdata();
this.leftedAudioDataCyclicContainer.cleardata();
};
const leftedAudioData = this.leftedAudioData;
// First column: consume the leftover samples first...
for (let k = 0; k < leftedAudioData.sampleLength; k++) {
for (let chN = 0; chN < leftedAudioData.channels.length; chN++) {
const cur_audio_sample = leftedAudioData.channels[chN][k];
// Map a sample in [-1, 1] to a bin inside this channel's vertical slice.
const cur_colI = Math.round((cur_audio_sample * 0.5 + 0.5 + chN) * per_wave_len);
// Row offset is 0 for the first row, so the bare bin index is correct.
wave_imgMatrix_count.typedArrayView[cur_colI] += 1;
};
};
// ...then top the first column up from the head of the new audio.
// NOTE(review): nf_left may be fractional, so this loop consumes
// ceil(nf_left) samples — confirm that is intended.
const nf_left = this.sample_nf_per_pixel - leftedSampleLength;
for (let k = 0; k < nf_left; k++) {
for (let chN = 0; chN < audioData.channels.length; chN++) {
const cur_audio_sample = audioData.channels[chN][k];
const cur_colI = Math.round((cur_audio_sample * 0.5 + 0.5 + chN) * per_wave_len);
wave_imgMatrix_count.typedArrayView[cur_colI] += 1;
};
};
row_i += 1;
each_pixel_sampleGroup_begin = this.sample_nf_per_pixel - leftedSampleLength;
each_imgMatrix_countRow_begin_i += this.wave_area_length;
// Remaining columns: each consumes ~sample_nf_per_pixel samples, tracked
// with float offsets rounded per column to avoid cumulative drift.
while (row_i < wave_image_length) {
const start_k = Math.round(each_pixel_sampleGroup_begin);
const end_k = start_k + this.sample_nf_per_pixel;
for (let k = start_k; k < end_k; k++) {
for (let chN = 0; chN < audioData.channels.length; chN++) {
const cur_audio_sample = audioData.channels[chN][k];
const cur_colI = Math.round((cur_audio_sample * 0.5 + 0.5 + chN) * per_wave_len);
wave_imgMatrix_count.typedArrayView[each_imgMatrix_countRow_begin_i + cur_colI] += 1;
};
};
each_pixel_sampleGroup_begin += this.sample_nf_per_pixel;
each_imgMatrix_countRow_begin_i += this.wave_area_length;
row_i += 1;
};
// Cache the unconsumed tail as a plain AudioData for the next call.
this.leftedAudioData = new AudioData(
audioData.sampleRate,
audioData.channels.map(ch => ch.slice(each_pixel_sampleGroup_begin)),
audioData.audioEndTime
);
} else {
// No complete column: accumulate everything into the cyclic container
// (flushing any previously cached leftover first) and drop the cache.
if (this.leftedAudioData) this.leftedAudioDataCyclicContainer.updatedata(this.leftedAudioData);
this.leftedAudioDataCyclicContainer.updatedata(
new AudioData(
audioData.sampleRate,
audioData.channels.map(ch => ch.slice(each_pixel_sampleGroup_begin)),
audioData.audioEndTime,
),
);
this.leftedAudioData = null;
};
if (!wave_image_length) return;
// Normalise counts to [0, 1], ease with circle_one, colour with num2color.
// The ImageData is built transposed (width = wave area, height = columns);
// presumably cyclicImageData.toImageDataT() transposes it back — TODO confirm.
const imageData = new ImageData(this.wave_area_length, wave_image_length);
for (let i = 0; i < wave_imgMatrix_count.typedArrayView.length; i += 1) {
const p = i * 4;
const cur_pixel = circle_one(wave_imgMatrix_count.typedArrayView[i] / this.sample_nf_per_pixel);
const color = num2color(cur_pixel);
imageData.data[p + 0] = color[0]; // R value
imageData.data[p + 1] = color[1]; // G value
imageData.data[p + 2] = color[2]; // B value
imageData.data[p + 3] = color[3]; // A value
};
return imageData;
};
/**
 * Render new audio: append its pixel columns to the cyclic image and schedule
 * a redraw via Drawer.setData.
 * @param {AudioData} audioData object of the form:
 * {
 *   sampleRate: Number, audio sample rate in Hz
 *   channels: Array[Float32Array], one entry per channel; every channel has
 *     the same length and each element is one sample
 *   timeStamp: Date.now(), timestamp of the end of the audio
 * }
 */
updateAudioData = (audioData) => {
this._checkAudioData(audioData);
const imageData = this.audioData2imageData(audioData);
if (!imageData) return;
this.cyclicImageData.update(imageData);
this.setData(
{
cyclicImageData: this.cyclicImageData,
audioEndTime: audioData.audioEndTime,
}
);
};
// Clear the canvas, blit the accumulated image right-aligned and, when
// enabled, draw time-axis ticks every 0.5 s labelled with the end-time offset.
draw = async ({ cyclicImageData, audioEndTime }) => {
this.canvas_ctx.clearRect(0, 0, this.canvas.width, this.canvas.height);
const imageData = cyclicImageData.toImageDataT();
this.canvas_ctx.putImageData(imageData,
this.canvas.width - imageData.width, 0,
);
if (this.show_time) {
const end_time = audioEndTime;
this.canvas_ctx.beginPath();
const dt = 0.5;
// Horizontal pixels per tick interval.
const time_dx = Math.round(dt * this.canvas.width / this.total_duration);
const s_y = this.canvas.height - 20, e_y = this.canvas.height - 10;
for (let i = 1; time_dx * i <= imageData.width; i += 1) {
this.canvas_ctx.moveTo(this.canvas.width - time_dx * i, s_y);
this.canvas_ctx.lineTo(this.canvas.width - time_dx * i, e_y);
this.canvas_ctx.fillText((end_time - dt * i).toFixed(3).toString(), this.canvas.width - time_dx * (i + 0.5), this.canvas.height);
};
this.canvas_ctx.stroke();
this.canvas_ctx.closePath();
};
};
};
export { WaveDrawer };
import { StftData } from "../../Audio/AudioContainer.js";
import { MyWorker } from "../../Workers/MyWorker.js";
// Wraps ASRModel behind a Web Worker so recognition runs off the main thread.
// Exposes the same async API as ASRModel; must be paired with the worker-side
// script "./WorkerASRModelScript.js".
class WorkerASRModel extends MyWorker {
constructor() {
const WebWorkScriptURL = './ASR/Model/WorkerASRModelScript.js';
super(WebWorkScriptURL)
// Mirror properties computed inside the worker (sent via 'SetProperties')
// onto this proxy object so callers can read them synchronously.
this.reciveData('SetProperties', (properties) => {
for (let prop_name in properties) {
this[prop_name] = properties[prop_name];
};
});
};
// Forward init to the worker once the worker reports it has been created.
init = async (ModelDir = './ASR/Model/Network/tensorflowjs/tfjsModel/tfjs_mobilev3small_thchs30/') => {
await this.CreatePromise;
return await this.executeAsyncWorkerFunction('init', ModelDir);
};
// Run recognition on raw audio inside the worker.
predictAudioData = (audioData) => {
return this.executeAsyncWorkerFunction('predictAudioData', audioData);
};
// Run recognition on precomputed STFT features inside the worker.
// NOTE(review): transferList is destructured but never passed on, so the stft
// ArrayBuffer is presumably structured-cloned rather than transferred —
// confirm whether executeAsyncWorkerFunction should receive transferList.
predictStftData = (stftData) => {
const { dataContent, transferList } = this.getStftData2Transfer(stftData);
return this.executeAsyncWorkerFunction('predictStftData', dataContent);
};
/**
 * Flatten a StftData into a plain postMessage-able object plus the list of
 * buffers that could be transferred instead of copied.
 * @param {StftData} stftData_Clip
 * @returns {Object} { dataContent, transferList }
 */
getStftData2Transfer = (stftData_Clip) => {
return {
dataContent: {
sampleRate: stftData_Clip.sampleRate,
fft_n: stftData_Clip.fft_n,
hop_n: stftData_Clip.hop_n,
stft: {
stftMartrixArrayBuffer: stftData_Clip.stft.arrayBuffer,
stftMartrixHeight: stftData_Clip.stft.rowsN,
stftMartrixWidth: stftData_Clip.stft.columnsN,
},
audioEndTime: stftData_Clip.audioEndTime,
},
transferList: [stftData_Clip.stft.arrayBuffer]
};
};
};
export { WorkerASRModel };
/*
 * Scraper boilerplate (download link and disclaimer from the source page),
 * kept as a comment so the file parses as JavaScript:
 * 完整的源代码可以登录【华纳网】下载。
 * https://www.worldwarner.com/
 * 免责声明:本文仅代表作者个人观点,与华纳网无关。其原创性以及文中陈述文字和内容未经本站证实,对本文以及其中全部或者部分内容、文字的真实性、完整性、及时性本站不作任何保证或承诺,请读者仅作参考,并请自行核实相关内容。
 */