Add openpilot tests
This commit is contained in:
106
tinygrad_repo/test/extra/test_utils.py
Normal file
106
tinygrad_repo/test/extra/test_utils.py
Normal file
@@ -0,0 +1,106 @@
|
||||
#!/usr/bin/env python
|
||||
import io, unittest
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from tinygrad.helpers import CI
|
||||
from extra.utils import fetch, temp, download_file
|
||||
from tinygrad.nn.state import torch_load
|
||||
from PIL import Image
|
||||
|
||||
@unittest.skipIf(CI, "no internet tests in CI")
|
||||
class TestFetch(unittest.TestCase):
|
||||
def test_fetch_bad_http(self):
|
||||
self.assertRaises(AssertionError, fetch, 'http://httpstat.us/500')
|
||||
self.assertRaises(AssertionError, fetch, 'http://httpstat.us/404')
|
||||
self.assertRaises(AssertionError, fetch, 'http://httpstat.us/400')
|
||||
|
||||
def test_fetch_small(self):
|
||||
assert(len(fetch('https://google.com'))>0)
|
||||
|
||||
def test_fetch_img(self):
|
||||
img = fetch("https://media.istockphoto.com/photos/hen-picture-id831791190")
|
||||
pimg = Image.open(io.BytesIO(img))
|
||||
assert pimg.size == (705, 1024)
|
||||
|
||||
class TestFetchRelative(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.working_dir = os.getcwd()
|
||||
self.tempdir = tempfile.TemporaryDirectory()
|
||||
os.chdir(self.tempdir.name)
|
||||
with open('test_file.txt', 'x') as f:
|
||||
f.write("12345")
|
||||
|
||||
def tearDown(self):
|
||||
os.chdir(self.working_dir)
|
||||
self.tempdir.cleanup()
|
||||
|
||||
#test ./
|
||||
def test_fetch_relative_dotslash(self):
|
||||
self.assertEqual(b'12345', fetch("./test_file.txt"))
|
||||
|
||||
#test ../
|
||||
def test_fetch_relative_dotdotslash(self):
|
||||
os.mkdir('test_file_path')
|
||||
os.chdir('test_file_path')
|
||||
self.assertEqual(b'12345', fetch("../test_file.txt"))
|
||||
|
||||
class TestDownloadFile(unittest.TestCase):
|
||||
def setUp(self):
|
||||
from pathlib import Path
|
||||
self.test_file = Path(temp("test_download_file/test_file.txt"))
|
||||
|
||||
def tearDown(self):
|
||||
os.remove(self.test_file)
|
||||
os.removedirs(self.test_file.parent)
|
||||
|
||||
@patch('requests.get')
|
||||
def test_download_file_with_mkdir(self, mock_requests):
|
||||
mock_response = MagicMock()
|
||||
mock_response.iter_content.return_value = [b'1234', b'5678']
|
||||
mock_response.status_code = 200
|
||||
mock_response.headers = {'content-length': '8'}
|
||||
mock_requests.return_value = mock_response
|
||||
self.assertFalse(self.test_file.parent.exists())
|
||||
download_file("https://www.mock.com/fake.txt", self.test_file, skip_if_exists=False)
|
||||
self.assertTrue(self.test_file.parent.exists())
|
||||
self.assertTrue(self.test_file.is_file())
|
||||
self.assertEqual('12345678', self.test_file.read_text())
|
||||
|
||||
class TestUtils(unittest.TestCase):
|
||||
def test_fake_torch_load_zipped(self): self._test_fake_torch_load_zipped()
|
||||
def test_fake_torch_load_zipped_float16(self): self._test_fake_torch_load_zipped(isfloat16=True)
|
||||
def _test_fake_torch_load_zipped(self, isfloat16=False):
|
||||
class LayerWithOffset(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(LayerWithOffset, self).__init__()
|
||||
d = torch.randn(16)
|
||||
self.param1 = torch.nn.Parameter(
|
||||
d.as_strided([2, 2], [1, 2], storage_offset=5)
|
||||
)
|
||||
self.param2 = torch.nn.Parameter(
|
||||
d.as_strided([2, 2], [1, 2], storage_offset=4)
|
||||
)
|
||||
|
||||
model = torch.nn.Sequential(
|
||||
torch.nn.Linear(4, 8),
|
||||
torch.nn.Linear(8, 3),
|
||||
LayerWithOffset()
|
||||
)
|
||||
if isfloat16: model = model.half()
|
||||
|
||||
path = temp(f"test_load_{isfloat16}.pt")
|
||||
torch.save(model.state_dict(), path)
|
||||
model2 = torch_load(path)
|
||||
|
||||
for name, a in model.state_dict().items():
|
||||
b = model2[name]
|
||||
a, b = a.numpy(), b.numpy()
|
||||
assert a.shape == b.shape
|
||||
assert a.dtype == b.dtype
|
||||
assert np.array_equal(a, b)
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user