(1.0.2) added get_voice() method and added handling for if the request code of get_tts() is between 400 and 600
This commit is contained in:
parent
8f52fda97c
commit
d585072d3e
4 changed files with 24 additions and 4 deletions
|
@ -15,6 +15,19 @@ class FloweryAPI:
|
||||||
self.config = config
|
self.config = config
|
||||||
self.adapter = RestAdapter(config)
|
self.adapter = RestAdapter(config)
|
||||||
|
|
||||||
|
async def get_voice(self, voice_id: str) -> Voice:
|
||||||
|
"""Get a voice from the Flowery API
|
||||||
|
|
||||||
|
Args:
|
||||||
|
voice_id (str): The ID of the voice
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Voice: The voice
|
||||||
|
"""
|
||||||
|
async for voice in self.get_voices():
|
||||||
|
if voice.id == voice_id:
|
||||||
|
return voice
|
||||||
|
|
||||||
async def get_voices(self) -> AsyncGenerator[Voice, None]:
|
async def get_voices(self) -> AsyncGenerator[Voice, None]:
|
||||||
"""Get a list of voices from the Flowery API
|
"""Get a list of voices from the Flowery API
|
||||||
|
|
||||||
|
@ -59,4 +72,6 @@ class FloweryAPI:
|
||||||
if voice:
|
if voice:
|
||||||
params['voice'] = voice.id if isinstance(voice, Voice) else voice
|
params['voice'] = voice.id if isinstance(voice, Voice) else voice
|
||||||
request = await self.adapter.get('/tts', params, timeout=180)
|
request = await self.adapter.get('/tts', params, timeout=180)
|
||||||
|
if request.status_code in range(400, 600):
|
||||||
|
raise ValueError(request.data['message'])
|
||||||
return request.data
|
return request.data
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
VERSION = "1.0.1"
|
VERSION = "1.0.2"
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "pyflowery"
|
name = "pyflowery"
|
||||||
version = "1.0.1"
|
version = "1.0.2"
|
||||||
description = "A Python API wrapper for the Flowery API"
|
description = "A Python API wrapper for the Flowery API"
|
||||||
authors = ["cswimr <seaswimmerthefsh@gmail.com>"]
|
authors = ["cswimr <seaswimmerthefsh@gmail.com>"]
|
||||||
license = "GPL 3.0-only"
|
license = "GPL 3.0-only"
|
||||||
|
|
|
@ -16,6 +16,10 @@ root.addHandler(handler)
|
||||||
|
|
||||||
api = FloweryAPI(FloweryAPIConfig())
|
api = FloweryAPI(FloweryAPIConfig())
|
||||||
|
|
||||||
|
ALEXANDER = "fa3ea565-121f-5efd-b4e9-59895c77df23" # TikTok
|
||||||
|
JACOB = "38f45366-68e8-5d39-b1ef-3fd4eeb61cdb" # Microsoft Azure
|
||||||
|
STORMTROOPER = "191c5adc-a092-5eea-b4ff-ce01f66153ae" # TikTok
|
||||||
|
|
||||||
async def test_get_voices():
|
async def test_get_voices():
|
||||||
"""Test the get_voices method"""
|
"""Test the get_voices method"""
|
||||||
async for voice in api.get_voices():
|
async for voice in api.get_voices():
|
||||||
|
@ -24,7 +28,8 @@ async def test_get_voices():
|
||||||
|
|
||||||
async def test_get_tts():
|
async def test_get_tts():
|
||||||
"""Test the get_tts method"""
|
"""Test the get_tts method"""
|
||||||
tts = await api.get_tts(text="BLAST HIM!", voice="191c5adc-a092-5eea-b4ff-ce01f66153ae")
|
voice = await api.get_voice(voice_id=ALEXANDER)
|
||||||
|
tts = await api.get_tts(text="Sphinx of black quartz, judge my vow. The quick brown fox jumps over a lazy dog.", voice=voice)
|
||||||
try:
|
try:
|
||||||
with open('test.mp3', 'wb') as f:
|
with open('test.mp3', 'wb') as f:
|
||||||
f.write(tts)
|
f.write(tts)
|
||||||
|
@ -34,7 +39,7 @@ async def test_get_tts():
|
||||||
try:
|
try:
|
||||||
await api.get_tts(text=long_string)
|
await api.get_tts(text=long_string)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
api.config.logger.error(e, exc_info=True)
|
api.config.logger.error("This is expected to fail:\n%s", e, exc_info=True)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
api.config.logger.info("testing get_voices")
|
api.config.logger.info("testing get_voices")
|
||||||
|
|
Loading…
Reference in a new issue