Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the realm-name retrieval #846

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ public void getActiveUserTest() {
DisconnectMessageModel disconnectMessageModel = new DisconnectMessageModel();
disconnectMessageModel.setClientId("clientId");
disconnectMessageModel.setUserId(USER);
disconnectMessageModel.setRealmId(REALM_RADIUS_NAME);
disconnectMessageModel.setRealmId(REALM_RADIUS_ID);
when(tableManager.getActiveSession(any(), any(), any())).thenReturn(disconnectMessageModel);
RadiusServiceModel activeUser = radiusService.getActiveUser("test", "test");
assertNotNull(activeUser);
Expand All @@ -195,7 +195,7 @@ public void getActiveUserNullTest() {
DisconnectMessageModel disconnectMessageModel = new DisconnectMessageModel();
disconnectMessageModel.setClientId("clientId");
disconnectMessageModel.setUserId(USER);
disconnectMessageModel.setRealmId(REALM_RADIUS_NAME);
disconnectMessageModel.setRealmId(REALM_RADIUS_ID);
when(tableManager.getActiveSession(any(), any(), any())).thenReturn(null);
RadiusServiceModel activeUser = radiusService.getActiveUser("test", "test");
assertNotNull(activeUser);
Expand All @@ -209,7 +209,7 @@ public void getRadiusInfoTest() {
DisconnectMessageModel disconnectMessageModel = new DisconnectMessageModel();
disconnectMessageModel.setClientId("clientId");
disconnectMessageModel.setUserId(USER);
disconnectMessageModel.setRealmId(REALM_RADIUS_NAME);
disconnectMessageModel.setRealmId(REALM_RADIUS_ID);

when(tableManager.getAllActiveSessions(any(), any()))
.thenReturn(Arrays.asList(disconnectMessageModel));
Expand All @@ -226,7 +226,7 @@ public void logoutTest() {
disconnectMessageModel.setClientId("clientId");
disconnectMessageModel.setKeycloakSessionId("sessionId");
disconnectMessageModel.setUserId(USER);
disconnectMessageModel.setRealmId(REALM_RADIUS_NAME);
disconnectMessageModel.setRealmId(REALM_RADIUS_ID);
disconnectMessageModel.setFramedIp("test");
when(tableManager.getActiveSession(any(), any(), any())).thenReturn(disconnectMessageModel);
RadiusServiceModel test = radiusService.logout("test", "test");
Expand All @@ -249,7 +249,7 @@ public void logoutwithoutIpTest() {
disconnectMessageModel.setClientId("clientId");
disconnectMessageModel.setKeycloakSessionId("sessionId");
disconnectMessageModel.setUserId(USER);
disconnectMessageModel.setRealmId(REALM_RADIUS_NAME);
disconnectMessageModel.setRealmId(REALM_RADIUS_ID);
disconnectMessageModel.setFramedIp(null);
when(tableManager.getActiveSession(any(), any(), any())).thenReturn(disconnectMessageModel);
RadiusServiceModel test = radiusService.logout("test", "test");
Expand All @@ -264,7 +264,7 @@ public void logoutSessionDoesNotExistTest() {
disconnectMessageModel.setClientId("clientId");
disconnectMessageModel.setKeycloakSessionId("sessionId");
disconnectMessageModel.setUserId(USER);
disconnectMessageModel.setRealmId(REALM_RADIUS_NAME);
disconnectMessageModel.setRealmId(REALM_RADIUS_ID);
disconnectMessageModel.setFramedIp("test");
when(tableManager.getActiveSession(any(), any(), any())).thenReturn(disconnectMessageModel);
when(userSessionProvider.getUserSession(eq(realmModel), anyString()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ private DisconnectMessageModel createDisconnectMessageModel() {
disconnectMessageModel.setKeycloakSessionId("testSession");
disconnectMessageModel.setUserId(USER);
disconnectMessageModel.setClientId(CLIENT_ID);
disconnectMessageModel.setRealmId(REALM_RADIUS_NAME);
disconnectMessageModel.setRealmId(REALM_RADIUS_ID);
disconnectMessageModel.setId("sessionId");
disconnectMessageModel.setCreatedDate(new Date(10000L));
disconnectMessageModel.setAddress("127.0.0.1");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ private static RealmModel getRealmFromUserName(KeycloakSession session,
.getValueString(), "@");
RealmModel realm = null;
if (StringUtils.isNotEmpty(realmName)) {
realm = session.realms().getRealm(realmName);
realm = session.realms().getRealmByName(realmName);
}
return (realm == null) ? getDefaultRealm(session) : realm;
}
Expand All @@ -202,7 +202,7 @@ private static RealmModel getRealm(KeycloakSession session,
for (String attribute : attributes) {
String realmName = getRealmName(attribute, radiusPacket);
if (realmName != null) {
return session.realms().getRealm(realmName);
return session.realms().getRealmByName(realmName);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public static EventBuilder createMasterEvent(
KeycloakSession session,
ClientConnection clientConnection) {

return createEvent(session, session.realms().getRealm(Config.getAdminRealm()),
return createEvent(session, session.realms().getRealmByName(Config.getAdminRealm()),
clientConnection);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
public class RadiusHelperTest extends AbstractRadiusTest {

private static final Logger LOGGER = LoggerFactory.getLogger(RadiusHelperTest.class);
private static final String SECOND_REALM_ID = "second_realm_id";
private static final String SECOND_REALM_NAME = "second_realm_name";
@Mock
private IRadiusServiceProvider radiusServiceProvider1;
@Mock
Expand All @@ -56,6 +58,14 @@ public void beforeMethods() {
.thenReturn(new AttributeType(3, "r3", "string"));
when(dictionary.getAttributeTypeByName("User-Name"))
.thenReturn(new AttributeType(4, "User-Name", "string"));
when(dictionary.getAttributeTypeByCode(-1, 1))
.thenReturn(new AttributeType(1, "realm-attribute", "string"));
when(dictionary.getAttributeTypeByCode(-1, 2))
.thenReturn(new AttributeType(2, "r", "string"));
when(dictionary.getAttributeTypeByCode(-1, 3))
.thenReturn(new AttributeType(3, "r3", "string"));
when(dictionary.getAttributeTypeByCode(-1,4))
.thenReturn(new AttributeType(4, "User-Name", "string"));
DictionaryLoader.getInstance().setWritableDictionary(realDictionary);
}

Expand Down Expand Up @@ -138,13 +148,17 @@ public void testGeneratePassword() {

@Test
public void testRealmAttributes() {
RealmModel secondRealm = mock(RealmModel.class);
when(secondRealm.getId()).thenReturn(REALM_RADIUS+"-id");
when(secondRealm.getName()).thenReturn(REALM_RADIUS);
when(realmProvider.getRealmByName(REALM_RADIUS)).thenReturn(secondRealm);
RadiusHelper.setRealmAttributes(Collections.singletonList("realm-attribute"));
RadiusPacket radiusPacket = RadiusPackets.create(dictionary, 1, 1);
radiusPacket.addAttribute("realm-attribute",
Hex.encodeHexString(REALM_RADIUS.getBytes(Charset.defaultCharset())));
radiusPacket.addAttribute("realm-attribute", REALM_RADIUS);
RealmModel realmModel = RadiusHelper.getRealm(session, radiusPacket);
assertNotNull(realmModel);
assertEquals(realmModel.getName(), REALM_RADIUS_NAME);
assertEquals(realmModel.getName(), REALM_RADIUS);
assertEquals(realmModel.getId(), REALM_RADIUS+"-id");
}

@Test
Expand All @@ -155,17 +169,18 @@ public void testRealmAttributesNullWithDefaultRealm() {
RealmModel realmModel = RadiusHelper.getRealm(session, radiusPacket);
assertNotNull(realmModel);
assertEquals(realmModel.getName(), REALM_RADIUS_NAME);
assertEquals(realmModel.getId(), REALM_RADIUS_ID);
}

@Test(expectedExceptions = IllegalStateException.class,
expectedExceptionsMessageRegExp =
"Found more than one Radius Realm \\(RadiusName, second_realm\\). " +
"Found more than one Radius Realm \\(" + REALM_RADIUS_NAME + ", " + SECOND_REALM_NAME + "\\). " +
"If you expect to use the Default Realm," +
" than you should use only one realm with radius client")
public void testRealmAttributesNullWith2DefaultRealm() {
RealmModel secondRealm = mock(RealmModel.class);
when(secondRealm.getId()).thenReturn("second_realm");
when(secondRealm.getName()).thenReturn("second_realm");
when(secondRealm.getId()).thenReturn(SECOND_REALM_ID);
when(secondRealm.getName()).thenReturn(SECOND_REALM_NAME);
ClientModel secondClientModel = mock(ClientModel.class);
when(secondClientModel.getProtocol()).thenReturn(RadiusLoginProtocolFactory.RADIUS_PROTOCOL);
when(secondRealm.getClientsStream()).thenAnswer(i -> Stream.of(secondClientModel));
Expand Down Expand Up @@ -230,36 +245,42 @@ public void testGetRealm() {
@Test
public void testRealmInUserName() {
RealmModel secondRealm = mock(RealmModel.class);
when(secondRealm.getId()).thenReturn("second_realm");
when(secondRealm.getName()).thenReturn("second_realm");
when(secondRealm.getId()).thenReturn(SECOND_REALM_ID);
when(secondRealm.getName()).thenReturn(SECOND_REALM_NAME);
ClientModel secondClientModel = mock(ClientModel.class);
when(secondClientModel.getProtocol())
.thenReturn(RadiusLoginProtocolFactory.RADIUS_PROTOCOL);
when(secondRealm.getClientsStream()).thenAnswer(i -> Stream.of(secondClientModel));
when(realmProvider.getRealmsStream()).thenAnswer(i -> Stream.of(realmModel, secondRealm));
when(realmProvider.getRealmByName(SECOND_REALM_NAME)).thenReturn(secondRealm);
RadiusHelper.setRealmAttributes(Collections.emptyList());
RadiusPacket radiusPacket = RadiusPackets.create(realDictionary, 1, 1);
radiusPacket.addAttribute("User-Name", USER + "@second_realm");
radiusPacket.addAttribute("User-Name", USER + "@" + SECOND_REALM_NAME);
RealmModel realmModel = RadiusHelper.getRealm(session, radiusPacket);
assertNotNull(realmModel);
assertEquals(realmModel.getName(), SECOND_REALM_NAME);
assertEquals(realmModel.getId(), SECOND_REALM_ID);
}

@Test
public void testRealmInUserEmail() {
when(userProvider.getUserByUsername(realmModel, USER)).thenReturn(null);
RealmModel secondRealm = mock(RealmModel.class);
when(secondRealm.getId()).thenReturn("second_realm");
when(secondRealm.getName()).thenReturn("second_realm");
when(secondRealm.getId()).thenReturn(SECOND_REALM_ID);
when(secondRealm.getName()).thenReturn(SECOND_REALM_NAME);
ClientModel secondClientModel = mock(ClientModel.class);
when(secondClientModel.getProtocol())
.thenReturn(RadiusLoginProtocolFactory.RADIUS_PROTOCOL);
when(secondRealm.getClientsStream()).thenAnswer(i -> Stream.of(secondClientModel));
when(realmProvider.getRealmsStream()).thenAnswer(i -> Stream.of(realmModel, secondRealm));
when(realmProvider.getRealmByName(SECOND_REALM_NAME)).thenReturn(secondRealm);
RadiusHelper.setRealmAttributes(Collections.emptyList());
RadiusPacket radiusPacket = RadiusPackets.create(realDictionary, 1, 1);
radiusPacket.addAttribute("User-Name", USER + "@second_realm");
radiusPacket.addAttribute("User-Name", USER + "@" + SECOND_REALM_NAME);
RealmModel realmModel = RadiusHelper.getRealm(session, radiusPacket);
assertNotNull(realmModel);
assertEquals(realmModel.getName(), SECOND_REALM_NAME);
assertEquals(realmModel.getId(), SECOND_REALM_ID);
}

@Test()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ public void testIsValid() {
public void testIsNotValid() {
// request.addAttribute(REALM_RADIUS, "33");
reset(realmProvider);
when(realmProvider.getRealm(Config.getAdminRealm())).thenReturn(realmModel);
when(realmProvider.getRealmByName(Config.getAdminRealm())).thenReturn(realmModel);
PAPProtocol papProtocol = new PAPProtocol(request, session);
assertFalse(papProtocol.isValid(new InetSocketAddress(0)));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ public void testgetRadiusPasswordsDisabledUser() {
@Test
public void testgetRadiusPasswordsRealmDoesNotExists() {
RealmProvider provider = getProvider(RealmProvider.class);
when(provider.getRealm(REALM_RADIUS_NAME)).thenReturn(null);
when(provider.getRealm(REALM_RADIUS_ID)).thenReturn(null);
when(authProtocol.getRealm()).thenReturn(null);
assertFalse(authRequestInitialization
.init(inetSocketAddress, USER, authProtocol, session));
Expand Down Expand Up @@ -163,7 +163,7 @@ public void testgetafterAuthEROOR() {
@Test
public void testgetafterAuthRealmERROR() {
RealmProvider realmProvider = getProvider(RealmProvider.class);
when(realmProvider.getRealm(REALM_RADIUS_NAME)).thenReturn(null);
when(realmProvider.getRealm(REALM_RADIUS_ID)).thenReturn(null);
authRequestInitialization
.afterAuth(4, session);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ public class KeycloakSessionUtilsTest extends AbstractRadiusTest {
public void getIsActiveSession() {
when(userSessionProvider.getUserSession(eq(realmModel), anyString()))
.thenReturn(null);
assertFalse(KeycloakSessionUtils.isActiveSession(session, "test", REALM_RADIUS_NAME));
assertFalse(KeycloakSessionUtils.isActiveSession(session, "test", REALM_RADIUS_ID));
}

@Test
public void getRadiusInfo() {
assertNotNull(KeycloakSessionUtils.getRadiusUserInfo(session));
assertNotNull(KeycloakSessionUtils.getRadiusSessionInfo(session));
assertTrue(KeycloakSessionUtils.isActiveSession(session, "test", REALM_RADIUS_NAME));
assertTrue(KeycloakSessionUtils.isActiveSession(session, "test", REALM_RADIUS_ID));
when(session.getAttribute(anyString(), any())).thenReturn(null);
assertNull(KeycloakSessionUtils.getRadiusSessionInfo(session));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.github.vzakharchenko.radius.radius.holder.IRadiusUserInfoGetter;
import com.github.vzakharchenko.radius.radius.server.KeycloakRadiusServer;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.keycloak.Config;
import org.keycloak.authorization.AuthorizationProvider;
import org.keycloak.authorization.policy.evaluation.PolicyEvaluator;
import org.keycloak.authorization.store.ResourceServerStore;
Expand Down Expand Up @@ -62,6 +63,7 @@ public abstract class AbstractRadiusTest {

public static final String RADIUS_SESSION_ID = "testSessionId";
public static final String REALM_RADIUS = "realm-radius";
public static final String REALM_RADIUS_ID = "01234567-89ab-cdef-0123-456789abcdef";
public static final String REALM_RADIUS_NAME = "RadiusName";
public static final String CLIENT_ID = "CLIENT_ID";
public static final String USER = "USER";
Expand Down Expand Up @@ -270,12 +272,13 @@ public void beforeRadiusMethod() {
when(clientModel.getId()).thenReturn(CLIENT_ID);
when(clientModel.getProtocol()).thenReturn(RadiusLoginProtocolFactory.RADIUS_PROTOCOL);
when(clientModel.isEnabled()).thenReturn(true);
when(realmProvider.getRealm(REALM_RADIUS_NAME)).thenReturn(realmModel);
when(realmProvider.getRealm(anyString())).thenReturn(realmModel);
when(realmProvider.getRealm(REALM_RADIUS_ID)).thenReturn(realmModel);
when(realmProvider.getRealmByName(REALM_RADIUS_NAME)).thenReturn(realmModel);
when(realmProvider.getRealmByName(Config.getAdminRealm())).thenReturn(realmModel);
when(realmProvider.getRealmsStream()).thenAnswer(i -> Stream.of(realmModel));
when(realmModel.getClientByClientId(CLIENT_ID)).thenReturn(clientModel);
when(realmModel.getName()).thenReturn(REALM_RADIUS_NAME);
when(realmModel.getId()).thenReturn(REALM_RADIUS_NAME);
when(realmModel.getId()).thenReturn(REALM_RADIUS_ID);
when(realmModel.isEventsEnabled()).thenReturn(false);
when(realmModel.getAttributes()).thenReturn(new HashMap<>());
when(realmModel.getClientsStream()).thenAnswer(i -> Stream.of(clientModel));
Expand Down